From 0d94783e760e27f056101d4b65dc1ebd86c83eda Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Fri, 4 Nov 2022 15:16:59 +0800
Subject: [PATCH] Instead of a pooling operation, use the first bottleneck_dim
 dimensions of the preceding self_attn.forward2 as the input to the
 squeeze-excite module.

---
 .../pruned_transducer_stateless7/zipformer.py | 37 +++++++------------
 1 file changed, 13 insertions(+), 24 deletions(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
index 76e616c33..0e7d33453 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
@@ -449,7 +449,8 @@ class ZipformerEncoderLayer(nn.Module):
         src = src + self.feed_forward2(src)
 
         if torch.jit.is_scripting() or use_self_attn:
-            src = src + self.self_attn.forward2(src, attn_weights)
+            self_attn_output2 = self.self_attn.forward2(src, attn_weights)
+            src = src + self_attn_output2
 
         if torch.jit.is_scripting() or random.random() > dynamic_dropout:
             src = src + self.conv_module2(src, src_key_padding_mask=src_key_padding_mask)
@@ -457,9 +458,10 @@ class ZipformerEncoderLayer(nn.Module):
         src = src + self.feed_forward3(src)
 
         # pooling module
-        if torch.jit.is_scripting() or random.random() > dynamic_dropout:
+        if torch.jit.is_scripting() or use_self_attn:
             src = src + self.squeeze_excite(src,
-                                            key_padding_mask=src_key_padding_mask)
+                                            self_attn_output2[...,:self.squeeze_excite.bottleneck_dim])
+
 
         src = self.norm_final(self.balancer(src))
 
@@ -1451,9 +1453,7 @@ class ModifiedSEModule(nn.Module):
                  d_model: int,
                  bottleneck_dim: int = 16):
         super().__init__()
-        self.squeeze_proj = nn.Linear(d_model, bottleneck_dim,
-                                      bias=False)
-
+        self.bottleneck_dim = bottleneck_dim
         self.in_proj = nn.Linear(d_model, d_model,
                                  bias=False)
 
@@ -1488,33 +1488,22 @@ class ModifiedSEModule(nn.Module):
 
     def forward(self,
                 x: Tensor,
-                key_padding_mask):
+                bottleneck: Tensor):
         """
         Args:
           x: a Tensor of shape (T, N, C)
-         key_padding_mask: a Tensor of bool, of shape (N, T), with True in masked
-              positions.
+         bottleneck: a Tensor of shape (1, N, bottleneck_dim) or (T, N, bottleneck_dim) that has
+              undergone some form of aggregation over time, e.g. attention.
        Returns:
            a Tensor of shape (1, N, C)
         """
-        if key_padding_mask is not None:
-            pooling_mask = key_padding_mask.logical_not().to(x.dtype)  # (N, T)
-            pooling_mask = (pooling_mask / pooling_mask.sum(dim=1, keepdim=True))
-            pooling_mask = pooling_mask.transpose(0, 1).contiguous().unsqueeze(-1)
-            # now pooling_mask: (T, N, 1)
-        else:
-            num_frames = x.shape[0]
-            pooling_mask = 1.0 / num_frames
-
-        squeezed = (x * pooling_mask).sum(dim=0, keepdim=True)
-        squeezed = self.squeeze_proj(squeezed)
-        squeezed = self.balancer(squeezed)
-        squeezed = self.activation(squeezed)
-        squeezed = self.from_bottleneck_proj(squeezed)
+        bottleneck = self.balancer(bottleneck)
+        bottleneck = self.activation(bottleneck)
+        scales = self.from_bottleneck_proj(bottleneck)
 
         x = self.in_proj(x)
-        x = x * squeezed
+        x = x * scales
 
         return self.out_whiten(self.out_proj(x))
 
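
A minimal usage sketch (illustration only, not part of the diff above): it shows how the patched interface is driven, with the squeeze-excite scales coming from the first bottleneck_dim channels of the second self-attention output instead of a masked mean-pool over time. TinySEModule, the sigmoid stand-in for the activation, and the random tensors are hypothetical; the real module also applies self.balancer and self.out_whiten, which are dropped here so the example stays self-contained.

import torch
import torch.nn as nn
from torch import Tensor

class TinySEModule(nn.Module):
    """Simplified stand-in for the patched ModifiedSEModule interface:
    the excitation comes from an externally supplied bottleneck tensor
    instead of being pooled over time inside the module."""
    def __init__(self, d_model: int, bottleneck_dim: int = 16):
        super().__init__()
        self.bottleneck_dim = bottleneck_dim
        self.in_proj = nn.Linear(d_model, d_model, bias=False)
        self.activation = nn.Sigmoid()  # stand-in for the real activation
        self.from_bottleneck_proj = nn.Linear(bottleneck_dim, d_model)
        self.out_proj = nn.Linear(d_model, d_model, bias=False)

    def forward(self, x: Tensor, bottleneck: Tensor) -> Tensor:
        # x: (T, N, C); bottleneck: (T, N, bottleneck_dim) or (1, N, bottleneck_dim)
        scales = self.from_bottleneck_proj(self.activation(bottleneck))
        return self.out_proj(self.in_proj(x) * scales)

# Usage mirroring the encoder-layer change: the first bottleneck_dim
# channels of the second self-attention output act as the excitation.
T, N, C = 10, 2, 64
se = TinySEModule(C, bottleneck_dim=16)
src = torch.randn(T, N, C)
self_attn_output2 = torch.randn(T, N, C)  # stands in for self_attn.forward2(src, attn_weights)
src = src + se(src, self_attn_output2[..., :se.bottleneck_dim])
print(src.shape)  # torch.Size([10, 2, 64])

Because the bottleneck input already comes out of an attention-weighted aggregation, the squeeze-excite path no longer needs a key_padding_mask, which is why that argument is removed in the patch.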