diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
index 0e7d33453..1b5cfb8f0 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
@@ -460,7 +460,7 @@ class ZipformerEncoderLayer(nn.Module):
         # pooling module
         if torch.jit.is_scripting() or use_self_attn:
             src = src + self.squeeze_excite(src,
-                                            self_attn_output2[...,:self.squeeze_excite.bottleneck_dim])
+                                            attn_weights)
 
         src = self.norm_final(self.balancer(src))
 
@@ -1458,6 +1458,10 @@ class ModifiedSEModule(nn.Module):
 
         self.in_proj = nn.Linear(d_model, d_model, bias=False)
 
+        self.to_bottleneck_proj = ScaledLinear(d_model,
+                                               bottleneck_dim,
+                                               bias=False)
+
         # Caution: this cannot work correctly with an extremely small batch size, e.g. if
         # we were training with a single very long audio sequence, or just 2 or 3 sequences
         # at a time.  We make max_factor small to reduce the harm this could cause
@@ -1484,21 +1488,30 @@ class ModifiedSEModule(nn.Module):
                                  grad_scale=0.01)
 
-
     def forward(self, x: Tensor,
-                bottleneck: Tensor):
+                attn_weights: Tensor):
         """
         Args:
           x: a Tensor of shape (T, N, C)
-          bottleneck: a Tensor of shape (1, N, bottleneck_dim) or (T, N, bottleneck_dim) that has
-              undergone some form of aggregation over time, e.g. attention.
+          attn_weights: a Tensor of shape (N * num_heads, seq_len, seq_len); we will only use the 1st head.
+
         Returns:
-           a Tensor of shape (1, N, C)
+           a Tensor of shape (T, N, C)
         """
+        (T, N, d_model) = x.shape
+        num_heads = attn_weights.shape[0] // N
+        attn_weights = attn_weights.reshape(N, num_heads, T, T)
+        attn_weights = attn_weights[:, 0]  # (N, T, T)
+
+        bottleneck = self.to_bottleneck_proj(x)  # (T, N, bottleneck_dim)
+        bottleneck = bottleneck.transpose(0, 1)  # (N, T, bottleneck_dim)
+
+        # (N, T, T) x (N, T, bottleneck_dim) -> (N, T, bottleneck_dim)
+        bottleneck = torch.bmm(attn_weights, bottleneck)
         bottleneck = self.balancer(bottleneck)
         bottleneck = self.activation(bottleneck)
+        bottleneck = bottleneck.transpose(0, 1)  # (T, N, bottleneck_dim)
         scales = self.from_bottleneck_proj(bottleneck)
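
Below is a minimal, self-contained sketch of what this patch does: instead of feeding the squeeze-excite module a single time-pooled bottleneck vector, every frame gets its own bottleneck summary, formed by weighting all frames' bottleneck projections with the first attention head's weights. The class name AttnWeightedSEModule is hypothetical; plain nn.Linear stands in for the repo's ScaledLinear, Tanh stands in for the balancer + activation pair, and the final sigmoid gating of x by `scales` is an assumption, since the hunk ends before showing how `scales` is applied.

import torch
import torch.nn as nn


class AttnWeightedSEModule(nn.Module):
    def __init__(self, d_model: int, bottleneck_dim: int):
        super().__init__()
        self.to_bottleneck_proj = nn.Linear(d_model, bottleneck_dim, bias=False)
        self.activation = nn.Tanh()  # stand-in for balancer + activation
        self.from_bottleneck_proj = nn.Linear(bottleneck_dim, d_model)

    def forward(self, x: torch.Tensor, attn_weights: torch.Tensor) -> torch.Tensor:
        # x: (T, N, d_model); attn_weights: (N * num_heads, T, T),
        # laid out as in PyTorch multi-head attention.
        (T, N, d_model) = x.shape
        num_heads = attn_weights.shape[0] // N
        # Keep only the 1st head: (N, T, T)
        attn_weights = attn_weights.reshape(N, num_heads, T, T)[:, 0]

        bottleneck = self.to_bottleneck_proj(x)  # (T, N, bottleneck_dim)
        bottleneck = bottleneck.transpose(0, 1)  # (N, T, bottleneck_dim)

        # Per-frame aggregation over time:
        # (N, T, T) @ (N, T, bottleneck_dim) -> (N, T, bottleneck_dim)
        bottleneck = torch.bmm(attn_weights, bottleneck)
        bottleneck = self.activation(bottleneck)
        bottleneck = bottleneck.transpose(0, 1)  # (T, N, bottleneck_dim)

        scales = self.from_bottleneck_proj(bottleneck)  # (T, N, d_model)
        # Assumption: gate the input as in classic squeeze-excite; the diff
        # above is truncated before the application of `scales`.
        return x * scales.sigmoid()


if __name__ == "__main__":
    T, N, num_heads, d_model = 6, 2, 4, 16
    se = AttnWeightedSEModule(d_model=d_model, bottleneck_dim=8)
    x = torch.randn(T, N, d_model)
    attn = torch.softmax(torch.randn(N * num_heads, T, T), dim=-1)
    print(se(x, attn).shape)  # torch.Size([6, 2, 16])

Design note: because each row of the attention weights sums to 1, the bmm makes every output frame a convex combination of all frames' bottleneck projections, so the SE scales become time-varying (shape (T, N, C)) rather than a single per-utterance vector, which is why the docstring's return shape changes from (1, N, C) to (T, N, C).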