Test, and fix, TransformerDecoderLayerRelPos

commit 7856ab89fc
parent 556fae586f
Author: Daniel Povey
Date:   2021-08-23 17:39:37 +08:00

2 changed files with 37 additions and 13 deletions

View File

@@ -249,7 +249,7 @@ class TransformerDecoderRelPos(nn.Module):
         """
         for mod in self.layers:
-            x = mod(x, pos_emb, memory, x_mask=x_mask,
+            x = mod(x, pos_emb, memory, attn_mask=attn_mask,
                     key_padding_mask=key_padding_mask)

         if self.norm is not None:
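For reference, the attn_mask that this loop now threads through to every layer is the usual additive causal mask of shape (T, T). A minimal standalone sketch (plain torch, not this repo's generate_square_subsequent_mask helper) of one way such a mask can be built, assuming float masking with -inf above the diagonal:

    import torch

    def causal_attn_mask(T: int) -> torch.Tensor:
        # Float mask of shape (T, T): 0.0 on and below the diagonal, -inf above it,
        # so output position t can only attend to input positions <= t.
        return torch.triu(torch.full((T, T), float("-inf")), diagonal=1)

    print(causal_attn_mask(4))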
@@ -294,7 +294,7 @@ class TransformerDecoderLayerRelPos(nn.Module):
         dropout: float = 0.1,
         activation: str = "relu",
     ) -> None:
-        super(TransformerDecoderLayer, self).__init__()
+        super(TransformerDecoderLayerRelPos, self).__init__()
         self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0)
         self.src_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0)
         # Implementation of Feedforward model
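The one-line fix above corrects a copy-paste bug: super() was being called with the wrong class name (TransformerDecoderLayer instead of TransformerDecoderLayerRelPos). As a side note, not part of this diff, the zero-argument form of super() on Python 3 avoids this class of mistake entirely; a tiny illustrative sketch with a hypothetical stand-in class:

    import torch.nn as nn

    class TransformerDecoderLayerRelPosSketch(nn.Module):
        # Hypothetical minimal stand-in, used only to show the zero-argument
        # super() form, which always resolves against the defining class.
        def __init__(self) -> None:
            super().__init__()  # no class name to get wrong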
@@ -315,14 +315,14 @@ class TransformerDecoderLayerRelPos(nn.Module):
     def __setstate__(self, state):
         if "activation" not in state:
             state["activation"] = nn.functional.relu
-        super(TransformerDecoderLayer, self).__setstate__(state)
+        super(TransformerDecoderLayerRelPos, self).__setstate__(state)

     def forward(
         self,
         x: torch.Tensor,
         pos_emb: torch.Tensor,
         memory: torch.Tensor,
-        x_mask: Optional[torch.Tensor] = None,
+        attn_mask: Optional[torch.Tensor] = None,
         key_padding_mask: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         """Pass the inputs (and mask) through the decoder layer.
@@ -330,13 +330,13 @@ class TransformerDecoderLayerRelPos(nn.Module):
         Args:
           x:
             The input embedding, to be added to by the forward function, of shape (T, N, C).
-            Attention within x will be left-to-right only (causal), thanks to x_mask.
+            Attention within x will be left-to-right only (causal), thanks to attn_mask.
           pos_emb:
             A torch.Tensor with dtype=torch.float and shape (1, 2*T-1, C) with C == num_channels,
             containing the relative positional encoding.
           memory:
             the sequence from the last layer of the encoder (required).  Shape = (T, N, C)
-          x_mask:
+          attn_mask:
             the mask for x, to enforce causal (left-to-right) attention (optional).
             Shape == (T, T); may be bool or float.  The first T pertains to the output,
             the second T to the input.
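The docstring above distinguishes the (T, T) attn_mask from the (N, T) key_padding_mask. As a sketch of the latter (plain torch, not part of this commit), assuming True marks padded positions to be ignored and that padding sits at the end of each sequence:

    import torch

    def key_padding_mask_from_lengths(lengths: torch.Tensor, T: int) -> torch.Tensor:
        # Bool mask of shape (N, T): True at positions beyond each sequence's length.
        return torch.arange(T, device=lengths.device).unsqueeze(0) >= lengths.unsqueeze(1)

    print(key_padding_mask_from_lengths(torch.tensor([3, 5]), T=5))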
@@ -351,15 +351,17 @@ class TransformerDecoderLayerRelPos(nn.Module):
         residual = x
         x = self.norm1(x)
         self_attn = self.self_attn(x, x, x,
+                                   pos_emb=pos_emb,
                                    key_padding_mask=key_padding_mask,
                                    need_weights=False,
-                                   attn_mask=x_mask,
+                                   attn_mask=attn_mask,
                                    )[0]
         x = residual + self.dropout1(self_attn)

         residual = x
         x = self.norm2(x)
         src_attn = self.src_attn(x, memory, memory,
+                                 pos_emb=pos_emb,
                                  key_padding_mask=key_padding_mask,
                                  need_weights=False,
                                  )[0]
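Both attention calls above follow the same pre-norm residual pattern, x = x + dropout(attn(norm(x))). A schematic sketch of that wiring (not this repo's exact code: standard nn.MultiheadAttention stands in for RelPositionMultiheadAttention, and pos_emb is omitted for brevity):

    import torch
    import torch.nn as nn

    class PreNormResidualAttnSketch(nn.Module):
        # Illustrative only: shows the pre-norm residual wiring used by the layer.
        def __init__(self, d_model: int, nhead: int, dropout: float = 0.1) -> None:
            super().__init__()
            self.norm = nn.LayerNorm(d_model)
            self.attn = nn.MultiheadAttention(d_model, nhead)
            self.dropout = nn.Dropout(dropout)

        def forward(self, x, attn_mask=None, key_padding_mask=None):
            residual = x
            x = self.norm(x)
            attn_out = self.attn(x, x, x, attn_mask=attn_mask,
                                 key_padding_mask=key_padding_mask,
                                 need_weights=False)[0]
            return residual + self.dropout(attn_out)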

View File

@@ -5,6 +5,7 @@
 import torch
 from conformer import (
     TransformerDecoderRelPos,
+    TransformerDecoderLayerRelPos,
     MaskedLmConformer,
     MaskedLmConformerEncoder,
     MaskedLmConformerEncoderLayer,
@@ -28,9 +29,9 @@ def test_rel_position_multihead_attention():
     x = torch.randn(N, T, C)
     #pos_emb = torch.randn(1, 2*T-1, C)
-    x, pos_enc = pos_emb_module(x)
+    x, pos_emb = pos_emb_module(x)
     x = x.transpose(0, 1)  # (T, N, C)
-    attn_output, attn_output_weights = rel_pos_multihead_attn(x, x, x, pos_enc)
+    attn_output, attn_output_weights = rel_pos_multihead_attn(x, x, x, pos_emb)


 def test_masked_lm_conformer_encoder_layer():
@@ -45,10 +46,10 @@ def test_masked_lm_conformer_encoder_layer():
     x = torch.randn(N, T, C)
-    x, pos_enc = pos_emb_module(x)
+    x, pos_emb = pos_emb_module(x)
     x = x.transpose(0, 1)  # (T, N, C)
     key_padding_mask = (torch.randn(N, T) > 0.0)  # (N, T)
-    y = encoder_layer(x, pos_enc, key_padding_mask=key_padding_mask)
+    y = encoder_layer(x, pos_emb, key_padding_mask=key_padding_mask)


 def test_masked_lm_conformer_encoder():
@@ -66,10 +67,31 @@ def test_masked_lm_conformer_encoder():
     x = torch.randn(N, T, C)
-    x, pos_enc = pos_emb_module(x)
+    x, pos_emb = pos_emb_module(x)
     x = x.transpose(0, 1)  # (T, N, C)
     key_padding_mask = (torch.randn(N, T) > 0.0)  # (N, T)
-    y = encoder(x, pos_enc, key_padding_mask=key_padding_mask)
+    y = encoder(x, pos_emb, key_padding_mask=key_padding_mask)
+
+
+def test_transformer_decoder_layer_rel_pos():
+    # Also tests RelPositionalEncoding
+    embed_dim = 256
+    num_heads = 4
+    T = 25
+    N = 4
+    C = 256
+
+    pos_emb_module = RelPositionalEncoding(C, dropout_rate=0.0)
+    decoder_layer = TransformerDecoderLayerRelPos(embed_dim, num_heads)
+
+    x = torch.randn(N, T, C)
+    x, pos_emb = pos_emb_module(x)
+    x = x.transpose(0, 1)  # (T, N, C)
+    key_padding_mask = (torch.randn(N, T) > 0.0)  # (N, T)
+    attn_mask = generate_square_subsequent_mask(T)
+    memory = torch.randn(T, N, C)
+    y = decoder_layer(x, pos_emb, memory, attn_mask=attn_mask, key_padding_mask=key_padding_mask)
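A natural companion (not part of this commit) would be a matching test for the stacked TransformerDecoderRelPos. A sketch under two assumptions: that its constructor takes (decoder_layer, num_layers) like torch.nn.TransformerDecoder, and that its forward signature mirrors the layer's, as the loop in the first hunk suggests:

    def test_transformer_decoder_rel_pos():
        embed_dim, num_heads, num_layers = 256, 4, 3
        T, N, C = 25, 4, 256
        pos_emb_module = RelPositionalEncoding(C, dropout_rate=0.0)
        decoder_layer = TransformerDecoderLayerRelPos(embed_dim, num_heads)
        decoder = TransformerDecoderRelPos(decoder_layer, num_layers)  # assumed signature
        x = torch.randn(N, T, C)
        x, pos_emb = pos_emb_module(x)
        x = x.transpose(0, 1)  # (T, N, C)
        key_padding_mask = (torch.randn(N, T) > 0.0)  # (N, T)
        attn_mask = generate_square_subsequent_mask(T)
        memory = torch.randn(T, N, C)
        y = decoder(x, pos_emb, memory, attn_mask=attn_mask,
                    key_padding_mask=key_padding_mask)
        assert y.shape == (T, N, C)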