Test, and fix, TransformerDecoderLayerRelPos

Daniel Povey 2021-08-23 17:39:37 +08:00
parent 556fae586f
commit 7856ab89fc
2 changed files with 37 additions and 13 deletions


@@ -249,7 +249,7 @@ class TransformerDecoderRelPos(nn.Module):
         """
         for mod in self.layers:
-            x = mod(x, pos_emb, memory, x_mask=x_mask,
+            x = mod(x, pos_emb, memory, attn_mask=attn_mask,
                     key_padding_mask=key_padding_mask)

         if self.norm is not None:
@@ -294,7 +294,7 @@ class TransformerDecoderLayerRelPos(nn.Module):
         dropout: float = 0.1,
         activation: str = "relu",
     ) -> None:
-        super(TransformerDecoderLayer, self).__init__()
+        super(TransformerDecoderLayerRelPos, self).__init__()
         self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0)
         self.src_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0)
         # Implementation of Feedforward model
@@ -315,14 +315,14 @@ class TransformerDecoderLayerRelPos(nn.Module):
     def __setstate__(self, state):
         if "activation" not in state:
             state["activation"] = nn.functional.relu
-        super(TransformerDecoderLayer, self).__setstate__(state)
+        super(TransformerDecoderLayerRelPos, self).__setstate__(state)

     def forward(
         self,
         x: torch.Tensor,
         pos_emb: torch.Tensor,
         memory: torch.Tensor,
-        x_mask: Optional[torch.Tensor] = None,
+        attn_mask: Optional[torch.Tensor] = None,
         key_padding_mask: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         """Pass the inputs (and mask) through the decoder layer.
@@ -330,13 +330,13 @@ class TransformerDecoderLayerRelPos(nn.Module):
         Args:
             x:
                 The input embedding, to be added to by the forward function, of shape (T, N, C).
-                Attention within x will be left-to-right only (causal), thanks to x_mask.
+                Attention within x will be left-to-right only (causal), thanks to attn_mask.
             pos_emb:
                 A torch.Tensor with dtype=torch.float and shape (1, 2*T-1, C) with C == num_channels,
                 containing the relative positional encoding.
             memory:
                 the sequence from the last layer of the encoder (required).  Shape = (T, N, C)
-            x_mask:
+            attn_mask:
                 the mask for the x, to enforce causal (left to right) attention (optional).
                 Shape == (T, T); may be bool or float.  The first T pertains to the output,
                 the second T to the input.
@@ -351,15 +351,17 @@ class TransformerDecoderLayerRelPos(nn.Module):
         residual = x
         x = self.norm1(x)
         self_attn = self.self_attn(x, x, x,
+                                   pos_emb=pos_emb,
                                    key_padding_mask=key_padding_mask,
                                    need_weights=False,
-                                   attn_mask=x_mask,
+                                   attn_mask=attn_mask,
                                    )[0]
         x = residual + self.dropout1(self_attn)

         residual = x
         x = self.norm2(x)
         src_attn = self.src_attn(x, memory, memory,
+                                 pos_emb=pos_emb,
                                  key_padding_mask=key_padding_mask,
                                  need_weights=False,
                                  )[0]
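The docstring above describes attn_mask as a (T, T) causal mask, bool or float, with the first T indexing output positions and the second T indexing input positions. For reference, here is a minimal sketch of building such a mask in float form; the helper name is hypothetical and only mirrors the standard generate_square_subsequent_mask that the new test calls, so the repo's own helper may differ in detail.

import torch

def causal_attn_mask(T: int) -> torch.Tensor:
    # Hypothetical helper, for illustration only.
    # Returns a (T, T) float mask: 0.0 where output position i may attend to
    # input position j (j <= i), -inf where attention is disallowed (j > i).
    return torch.triu(torch.full((T, T), float("-inf")), diagonal=1)

# For T == 4 the mask looks like:
# tensor([[0., -inf, -inf, -inf],
#         [0.,   0., -inf, -inf],
#         [0.,   0.,   0., -inf],
#         [0.,   0.,   0.,   0.]])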


@@ -5,6 +5,7 @@
 import torch
 from conformer import (
     TransformerDecoderRelPos,
+    TransformerDecoderLayerRelPos,
     MaskedLmConformer,
     MaskedLmConformerEncoder,
     MaskedLmConformerEncoderLayer,
@@ -28,9 +29,9 @@ def test_rel_position_multihead_attention():
     x = torch.randn(N, T, C)
     #pos_emb = torch.randn(1, 2*T-1, C)
-    x, pos_enc = pos_emb_module(x)
+    x, pos_emb = pos_emb_module(x)
     x = x.transpose(0, 1)  # (T, N, C)
-    attn_output, attn_output_weights = rel_pos_multihead_attn(x, x, x, pos_enc)
+    attn_output, attn_output_weights = rel_pos_multihead_attn(x, x, x, pos_emb)


 def test_masked_lm_conformer_encoder_layer():
@@ -45,10 +46,10 @@ def test_masked_lm_conformer_encoder_layer():
     x = torch.randn(N, T, C)
-    x, pos_enc = pos_emb_module(x)
+    x, pos_emb = pos_emb_module(x)
     x = x.transpose(0, 1)  # (T, N, C)
     key_padding_mask = (torch.randn(N, T) > 0.0)  # (N, T)
-    y = encoder_layer(x, pos_enc, key_padding_mask=key_padding_mask)
+    y = encoder_layer(x, pos_emb, key_padding_mask=key_padding_mask)


 def test_masked_lm_conformer_encoder():
@@ -66,10 +67,31 @@ def test_masked_lm_conformer_encoder():
     x = torch.randn(N, T, C)
-    x, pos_enc = pos_emb_module(x)
+    x, pos_emb = pos_emb_module(x)
     x = x.transpose(0, 1)  # (T, N, C)
     key_padding_mask = (torch.randn(N, T) > 0.0)  # (N, T)
-    y = encoder(x, pos_enc, key_padding_mask=key_padding_mask)
+    y = encoder(x, pos_emb, key_padding_mask=key_padding_mask)
+
+
+def test_transformer_decoder_layer_rel_pos():
+    # Also tests RelPositionalEncoding
+    embed_dim = 256
+    num_heads = 4
+    T = 25
+    N = 4
+    C = 256
+
+    pos_emb_module = RelPositionalEncoding(C, dropout_rate=0.0)
+    decoder_layer = TransformerDecoderLayerRelPos(embed_dim, num_heads)
+
+    x = torch.randn(N, T, C)
+    x, pos_emb = pos_emb_module(x)
+    x = x.transpose(0, 1)  # (T, N, C)
+    key_padding_mask = (torch.randn(N, T) > 0.0)  # (N, T)
+    attn_mask = generate_square_subsequent_mask(T)
+    memory = torch.randn(T, N, C)
+
+    y = decoder_layer(x, pos_emb, memory, attn_mask=attn_mask, key_padding_mask=key_padding_mask)
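The new test drives a single TransformerDecoderLayerRelPos. Below is a rough sketch of exercising a whole stack the way the loop in the first hunk's TransformerDecoderRelPos.forward does. The stack's constructor signature, the import location of RelPositionalEncoding, and the inline causal mask are assumptions (modeled on torch.nn.TransformerDecoder and on the test above), not taken from this diff.

import torch
from conformer import (
    RelPositionalEncoding,            # assumed to live in conformer.py, as in the tests above
    TransformerDecoderLayerRelPos,
    TransformerDecoderRelPos,
)

T, N, C, num_heads, num_layers = 25, 4, 256, 4, 6

pos_emb_module = RelPositionalEncoding(C, dropout_rate=0.0)
decoder_layer = TransformerDecoderLayerRelPos(C, num_heads)
# Assumption: the stack takes (layer, num_layers) like torch.nn.TransformerDecoder.
decoder = TransformerDecoderRelPos(decoder_layer, num_layers)

x = torch.randn(N, T, C)
x, pos_emb = pos_emb_module(x)                 # pos_emb: (1, 2*T-1, C)
x = x.transpose(0, 1)                          # (T, N, C)
memory = torch.randn(T, N, C)                  # stand-in for encoder output
# True marks a padded key position, per the usual torch MultiheadAttention convention.
key_padding_mask = torch.randn(N, T) > 0.0     # (N, T)
# Causal (T, T) float mask, equivalent to generate_square_subsequent_mask(T).
attn_mask = torch.triu(torch.full((T, T), float("-inf")), diagonal=1)

y = decoder(x, pos_emb, memory, attn_mask=attn_mask,
            key_padding_mask=key_padding_mask)
assert y.shape == (T, N, C)

The final assert only checks that the residual path preserves the (T, N, C) shape promised by the layer's docstring.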