diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py index 34b9d00f0..f90bb9379 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py @@ -269,7 +269,11 @@ class ConformerEncoderLayer(nn.Module): src = src + self.dropout(src_att) # convolution module - src = src + self.dropout(self.conv_module(src)) + src = src + self.dropout( + self.conv_module(src, src_key_padding_mask=src_key_padding_mask) + ) + + # feed forward module src = src + self.dropout(self.feed_forward(src)) @@ -925,11 +929,16 @@ class ConvolutionModule(nn.Module): initial_scale=0.5, ) - def forward(self, x: Tensor) -> Tensor: + def forward(self, + x: Tensor, + src_key_padding_mask: Optional[Tensor] = None, + ) -> Tensor: """Compute convolution module. Args: x: Input tensor (#time, batch, channels). + src_key_padding_mask: the mask for the src keys per batch (optional): + (batch, #time), contains bool in masked positions. Returns: Tensor: Output tensor (#time, batch, channels). @@ -944,6 +953,9 @@ class ConvolutionModule(nn.Module): x = self.deriv_balancer1(x) x = nn.functional.glu(x, dim=1) # (batch, channels, time) + if src_key_padding_mask is not None: + x.masked_fill_(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0) + # 1D Depthwise Conv x = self.depthwise_conv(x)