diff --git a/egs/librispeech/ASR/conformer_lm/conformer.py b/egs/librispeech/ASR/conformer_lm/conformer.py
index 163f47543..e158e88d5 100644
--- a/egs/librispeech/ASR/conformer_lm/conformer.py
+++ b/egs/librispeech/ASR/conformer_lm/conformer.py
@@ -249,7 +249,7 @@ class TransformerDecoderRelPos(nn.Module):
 
         """
         for mod in self.layers:
-            x = mod(x, pos_emb, memory, x_mask=x_mask,
+            x = mod(x, pos_emb, memory, attn_mask=attn_mask,
                     key_padding_mask=key_padding_mask)
 
         if self.norm is not None:
@@ -294,7 +294,7 @@ class TransformerDecoderLayerRelPos(nn.Module):
         dropout: float = 0.1,
         activation: str = "relu",
     ) -> None:
-        super(TransformerDecoderLayer, self).__init__()
+        super(TransformerDecoderLayerRelPos, self).__init__()
         self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0)
         self.src_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0)
         # Implementation of Feedforward model
@@ -315,14 +315,14 @@ class TransformerDecoderLayerRelPos(nn.Module):
     def __setstate__(self, state):
         if "activation" not in state:
             state["activation"] = nn.functional.relu
-        super(TransformerDecoderLayer, self).__setstate__(state)
+        super(TransformerDecoderLayerRelPos, self).__setstate__(state)
 
     def forward(
         self,
         x: torch.Tensor,
         pos_emb: torch.Tensor,
         memory: torch.Tensor,
-        x_mask: Optional[torch.Tensor] = None,
+        attn_mask: Optional[torch.Tensor] = None,
         key_padding_mask: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         """Pass the inputs (and mask) through the decoder layer.
@@ -330,13 +330,13 @@ class TransformerDecoderLayerRelPos(nn.Module):
         Args:
            x
                The input embedding, to be added to by the forward function, of shape (T, N, C).
-               Attention within x will be left-to-right only (causal), thanks to x_mask.
+               Attention within x will be left-to-right only (causal), thanks to attn_mask.
           pos_emb:
                A torch.Tensor with dtype=torch.float and shape (1, 2*T-1, C) with c==num_channels,
               containing the relative positional encoding.
          memory:
               the sequence from the last layer of the encoder (required). Shape = (T, N, C)
-        x_mask:
+        attn_mask:
              the mask for the x, to enforce causal (left to right) attention (optional).
             Shape == (T, T); may be bool or float.  The first T pertains to the output,
             the second T to the input.
@@ -351,15 +351,17 @@ class TransformerDecoderLayerRelPos(nn.Module):
         residual = x
         x = self.norm1(x)
         self_attn = self.self_attn(x, x, x,
+                                    pos_emb=pos_emb,
                                     key_padding_mask=key_padding_mask,
                                     need_weights=False,
-                                    attn_mask=x_mask,
+                                    attn_mask=attn_mask,
                                     )[0]
         x = residual + self.dropout1(self_attn)
 
         residual = x
         x = self.norm2(x)
         src_attn = self.src_attn(x, memory, memory,
+                                  pos_emb=pos_emb,
                                   key_padding_mask=key_padding_mask,
                                   need_weights=False,
                                   )[0]
diff --git a/egs/librispeech/ASR/conformer_lm/test_conformer.py b/egs/librispeech/ASR/conformer_lm/test_conformer.py
index 8c2b2efa4..106b84738 100644
--- a/egs/librispeech/ASR/conformer_lm/test_conformer.py
+++ b/egs/librispeech/ASR/conformer_lm/test_conformer.py
@@ -5,6 +5,7 @@ import torch
 
 from conformer import (
     TransformerDecoderRelPos,
+    TransformerDecoderLayerRelPos,
     MaskedLmConformer,
     MaskedLmConformerEncoder,
     MaskedLmConformerEncoderLayer,
@@ -28,9 +29,9 @@ def test_rel_position_multihead_attention():
 
     x = torch.randn(N, T, C)
     #pos_emb = torch.randn(1, 2*T-1, C)
-    x, pos_enc = pos_emb_module(x)
+    x, pos_emb = pos_emb_module(x)
     x = x.transpose(0, 1)  # (T, N, C)
-    attn_output, attn_output_weights = rel_pos_multihead_attn(x, x, x, pos_enc)
+    attn_output, attn_output_weights = rel_pos_multihead_attn(x, x, x, pos_emb)
 
 
 def test_masked_lm_conformer_encoder_layer():
@@ -45,10 +46,10 @@ def test_masked_lm_conformer_encoder_layer():
 
 
     x = torch.randn(N, T, C)
-    x, pos_enc = pos_emb_module(x)
+    x, pos_emb = pos_emb_module(x)
     x = x.transpose(0, 1)  # (T, N, C)
     key_padding_mask = (torch.randn(N, T) > 0.0)  # (N, T)
-    y = encoder_layer(x, pos_enc, key_padding_mask=key_padding_mask)
+    y = encoder_layer(x, pos_emb, key_padding_mask=key_padding_mask)
 
 
 def test_masked_lm_conformer_encoder():
@@ -66,10 +67,31 @@ def test_masked_lm_conformer_encoder():
 
 
     x = torch.randn(N, T, C)
-    x, pos_enc = pos_emb_module(x)
+    x, pos_emb = pos_emb_module(x)
     x = x.transpose(0, 1)  # (T, N, C)
     key_padding_mask = (torch.randn(N, T) > 0.0)  # (N, T)
-    y = encoder(x, pos_enc, key_padding_mask=key_padding_mask)
+    y = encoder(x, pos_emb, key_padding_mask=key_padding_mask)
+
+
+def test_transformer_decoder_layer_rel_pos():
+    # Also tests RelPositionalEncoding
+    embed_dim = 256
+    num_heads = 4
+    T = 25
+    N = 4
+    C = 256
+    pos_emb_module = RelPositionalEncoding(C, dropout_rate=0.0)
+    decoder_layer = TransformerDecoderLayerRelPos(embed_dim, num_heads)
+
+
+    x = torch.randn(N, T, C)
+    x, pos_emb = pos_emb_module(x)
+    x = x.transpose(0, 1)  # (T, N, C)
+    key_padding_mask = (torch.randn(N, T) > 0.0)  # (N, T)
+    attn_mask = generate_square_subsequent_mask(T)
+    memory = torch.randn(T, N, C)
+    y = decoder_layer(x, pos_emb, memory, attn_mask=attn_mask, key_padding_mask=key_padding_mask)
+
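For context on the renamed `attn_mask` argument: the new test builds it with `generate_square_subsequent_mask(T)`, a (T, T) causal mask in which row t (the output position) may only attend to input positions <= t, matching the docstring updated above. The sketch below illustrates one common way such a helper is written; the actual definition imported from `conformer.py` may differ in detail (for example, it could return a bool mask rather than a float one), so treat this as an illustration, not the repository's implementation.

```python
import torch


def generate_square_subsequent_mask_sketch(sz: int) -> torch.Tensor:
    """Illustrative stand-in for the test's generate_square_subsequent_mask helper.

    Returns a (sz, sz) float mask: 0.0 where attention is allowed (input position
    <= output position) and -inf where it is disallowed; the mask is added to the
    attention logits before the softmax.
    """
    mask = torch.triu(torch.ones(sz, sz), diagonal=1)
    return mask.masked_fill(mask == 1.0, float("-inf"))


print(generate_square_subsequent_mask_sketch(4))
# tensor([[0., -inf, -inf, -inf],
#         [0., 0., -inf, -inf],
#         [0., 0., 0., -inf],
#         [0., 0., 0., 0.]])
```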