diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
index 1bbfe3105..f7efcf458 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
@@ -343,10 +343,12 @@ class ZipformerEncoderLayer(nn.Module):
                                                feedforward_dim, dropout)
 
-        self.conv_module1 = ConvolutionModule(d_model,
-                                              cnn_module_kernel)
+        #self.conv_module1 = ConvolutionModule(d_model,
+        #cnn_module_kernel)
+        self.nonlin_attention_module = NonlinAttentionModule(d_model)
 
-        self.conv_module2 = ConvolutionModule(d_model,
+
+        self.conv_module = ConvolutionModule(d_model,
                                               cnn_module_kernel)
 
@@ -444,27 +446,29 @@ class ZipformerEncoderLayer(nn.Module):
             src = src + src_att
 
         # convolution module
-        if torch.jit.is_scripting() or random.random() > dynamic_dropout:
-            src = src + self.conv_module1(src, src_key_padding_mask=src_key_padding_mask)
+        if torch.jit.is_scripting() or use_self_attn:
+            src = src + self.nonlin_attention_module(src,
+                                                     attn_weights,
+                                                     head_idx=0)
 
         src = src + self.feed_forward2(src)
 
         # pooling module
         if torch.jit.is_scripting() or use_self_attn:
-            src = src + self.squeeze_excite1(src, attn_weights, attn_weights_idx=0)
+            src = src + self.squeeze_excite1(src, attn_weights, head_idx=1)
 
         if torch.jit.is_scripting() or use_self_attn:
             self_attn_output2 = self.self_attn.forward2(src, attn_weights)
             src = src + self_attn_output2
 
         if torch.jit.is_scripting() or random.random() > dynamic_dropout:
-            src = src + self.conv_module2(src, src_key_padding_mask=src_key_padding_mask)
+            src = src + self.conv_module(src, src_key_padding_mask=src_key_padding_mask)
 
         src = src + self.feed_forward3(src)
 
         # pooling module
         if torch.jit.is_scripting() or use_self_attn:
-            src = src + self.squeeze_excite2(src, attn_weights, attn_weights_idx=1)
+            src = src + self.squeeze_excite2(src, attn_weights, head_idx=2)
 
         src = self.norm_final(self.balancer(src))
 
@@ -1495,19 +1499,20 @@ class ModifiedSEModule(nn.Module):
     def forward(self,
                 x: Tensor,
                 attn_weights: Tensor,
-                attn_weights_idx: int):
+                head_idx: int):
         """
         Args:
             x: a Tensor of shape (T, N, C)
-            attn_weights: a Tensor of shape (N * num_heads, seq_len, seq_len), we will only use the 1st head.
-            attn_weights_idx: indicates which head to choose from attn_weights
+            attn_weights: a Tensor of shape (N * num_heads, seq_len, seq_len), we will only use the head indexed
+              `head_idx`
+            head_idx: indicates which head to choose from attn_weights
         Returns:
             a Tensor of shape (T, N, C)
         """
         (T, N, d_model) = x.shape
         num_heads = attn_weights.shape[0] // N
         attn_weights = attn_weights.reshape(N, num_heads, T, T)
-        attn_weights = attn_weights[:,attn_weights_idx]  # (N, T, T)
+        attn_weights = attn_weights[:,head_idx]  # (N, T, T)
 
         bottleneck = self.to_bottleneck_proj(x)  # (T, N, C)
         bottleneck = bottleneck.transpose(0, 1)  # (N, T, bottleneck_dim)
@@ -1552,6 +1557,80 @@ class FeedforwardModule(nn.Module):
         return x
 
 
+class NonlinAttentionModule(nn.Module):
+    """This is like the ConvolutionModule, but refactored so that we use multiplication by attention weights (borrowed
+    from the attention module) in place of actual convolution.
+
+    Args:
+        channels (int): The number of channels of conv layers.
+    """
+
+    def __init__(
+            self, channels: int,
+    ) -> None:
+        super().__init__()
+
+        # to_scale and to_value are analogous to pointwise_conv1 in ConvolutionModule;
+        # we make them separate because we need an extra degree of freedom for the
+        # scale, as the attention weights are constrained to sum to one so cannot
+        # provide the degree of freedom for the scale of the features before
+        # self.activation().
+        self.to_scale = nn.Linear(channels, channels, bias=True)
+        self.to_value = nn.Linear(channels, channels, bias=True)
+
+        # deriv_balancer corresponds to deriv_balancer2 in ConvolutionModule
+        self.deriv_balancer = ActivationBalancer(
+            channels, channel_dim=1,
+            min_positive=0.05, max_positive=1.0,
+            max_abs=20.0,
+        )
+
+        self.activation = DoubleSwish()
+
+        self.out_proj = ScaledLinear(channels, channels,
+                                     bias=True,
+                                     initial_scale=0.05)
+
+    def forward(self,
+                x: Tensor,
+                attn_weights: Tensor,
+                head_idx: int,
+                ) -> Tensor:
+        """
+        Args:
+            x: a Tensor of shape (T, N, C), i.e. (time, batch, channels)
+            attn_weights: a Tensor of shape (N * num_heads, seq_len, seq_len); we will only use the head
+              selected by head_idx.
+            head_idx: indicates which head to choose from attn_weights
+        Returns:
+            a Tensor of shape (T, N, C)
+        """
+        s = self.to_scale(x)
+        v = self.to_value(x)
+        if self.training and random.random() < 0.02:
+            # prevent the inputs to the sigmoid from getting very large (this is
+            # unlikely to happen in this particular module, so giving this path
+            # a very small probability).
+            s = penalize_abs_values_gt(s, limit=20.0, penalty=1.0e-04)
+
+        # GLU mechanism
+        x = s.sigmoid() * v
+
+        (T, N, d_model) = x.shape
+        num_heads = attn_weights.shape[0] // N
+        attn_weights = attn_weights.reshape(N, num_heads, T, T)
+        attn_weights = attn_weights[:,head_idx]  # (N, T, T)
+        x = x.transpose(0, 1)  # (N, T, C)
+        x = torch.bmm(attn_weights, x)
+        x = self.deriv_balancer(x)
+        x = x.transpose(0, 1)  # (T, N, C)
+        x = self.activation(x)
+        x = self.out_proj(x)
+
+        return x
+
+
 class ConvolutionModule(nn.Module):
     """ConvolutionModule in Zipformer model.
     Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/zipformer/convolution.py