diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py index 1b5cfb8f0..1bbfe3105 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py @@ -350,7 +350,8 @@ class ZipformerEncoderLayer(nn.Module): cnn_module_kernel) - self.squeeze_excite = ModifiedSEModule(d_model) + self.squeeze_excite1 = ModifiedSEModule(d_model) + self.squeeze_excite2 = ModifiedSEModule(d_model) self.norm_final = BasicNorm(d_model) @@ -448,6 +449,10 @@ class ZipformerEncoderLayer(nn.Module): src = src + self.feed_forward2(src) + # pooling module + if torch.jit.is_scripting() or use_self_attn: + src = src + self.squeeze_excite1(src, attn_weights, attn_weights_idx=0) + if torch.jit.is_scripting() or use_self_attn: self_attn_output2 = self.self_attn.forward2(src, attn_weights) src = src + self_attn_output2 @@ -459,8 +464,7 @@ class ZipformerEncoderLayer(nn.Module): # pooling module if torch.jit.is_scripting() or use_self_attn: - src = src + self.squeeze_excite(src, - attn_weights) + src = src + self.squeeze_excite2(src, attn_weights, attn_weights_idx=1) src = self.norm_final(self.balancer(src)) @@ -1490,19 +1494,20 @@ class ModifiedSEModule(nn.Module): def forward(self, x: Tensor, - attn_weights: Tensor): + attn_weights: Tensor, + attn_weights_idx: int): """ Args: x: a Tensor of shape (T, N, C) attn_weights: a Tensor of shape (N * num_heads, seq_len, seq_len), we will only use the 1st head. - +attn_weights_idx: indicates which head to choose from attn_weights Returns: a Tensor of shape (T, N, C) """ (T, N, d_model) = x.shape num_heads = attn_weights.shape[0] // N attn_weights = attn_weights.reshape(N, num_heads, T, T) - attn_weights = attn_weights[:,0] # (N, T, T) + attn_weights = attn_weights[:,attn_weights_idx] # (N, T, T) bottleneck = self.to_bottleneck_proj(x) # (T, N, C) bottleneck = bottleneck.transpose(0, 1) # (N, T, bottleneck_dim)