Test, and fix, TransformerDecoderRelPos

Daniel Povey 2021-08-23 17:48:00 +08:00
parent 7856ab89fc
commit 5fecd24664
2 changed files with 28 additions and 10 deletions


@@ -166,24 +166,24 @@ class MaskedLmConformer(nn.Module):
         tgt_mask = generate_square_subsequent_mask(T, memory.device)
-        src = self.embed(src_symbols) * self.embed_scale  # (N, T) -> (N, T, C)
-        src = src.permute(1, 0, 2)  # (N, T, C) -> (T, N, C)
+        x = self.embed(src_symbols) * self.embed_scale  # (N, T) -> (N, T, C)
+        x = x.permute(1, 0, 2)  # (N, T, C) -> (T, N, C)
-        src = memory + self.src_linear(src)  # (T, N, C)
+        x = memory + self.src_linear(x)  # (T, N, C)
         # This is a little confusing, how "tgt" is set to src. "src" is the
         # symbol sequence without masking but with padding and randomization.
         # "tgt" is like "src" but shifted by one.
         pred = self.decoder(
-            tgt=src,
+            x,
             pos_emb,
             memory=memory,
             tgt_mask=tgt_mask,
             tgt_key_padding_mask=key_padding_mask,
             memory_key_padding_mask=key_padding_mask,
         )  # (T, N, C)
-        pred = pred_pad.permute(1, 0, 2)  # (T, N, C) -> (N, T, C)
+        pred = pred.permute(1, 0, 2)  # (T, N, C) -> (N, T, C)
         pred = self.decoder_output_layer(pred)  # (N, T, C)
         # nll: negative log-likelihood
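
The renaming above (src -> x, and the stray pred_pad -> pred) only touches bookkeeping, so the shape flow can be checked in isolation. Below is a minimal, self-contained sketch of that flow using stand-in modules (nn.Embedding, nn.Linear) and made-up sizes; the decoder call itself is elided, and embed_scale / vocab are assumptions, not values taken from the repository.

import torch
import torch.nn as nn

N, T, C, vocab = 4, 25, 256, 1000
embed = nn.Embedding(vocab, C)
src_linear = nn.Linear(C, C)
decoder_output_layer = nn.Linear(C, vocab)
embed_scale = C ** 0.5          # assumed sqrt(d_model) scaling

src_symbols = torch.randint(0, vocab, (N, T))
memory = torch.randn(T, N, C)   # stands in for the encoder output, (T, N, C)

x = embed(src_symbols) * embed_scale  # (N, T) -> (N, T, C)
x = x.permute(1, 0, 2)                # (N, T, C) -> (T, N, C)
x = memory + src_linear(x)            # (T, N, C)
# ... self.decoder(x, pos_emb, memory, ...) runs here and keeps (T, N, C) ...
pred = x.permute(1, 0, 2)             # (T, N, C) -> (N, T, C)
pred = decoder_output_layer(pred)     # (N, T, C) -> (N, T, vocab)
assert pred.shape == (N, T, vocab)
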
@@ -247,15 +247,14 @@ class TransformerDecoderRelPos(nn.Module):
             a torch.Tensor with dtype=bool and shape (N, T): true for masked
             positions after the ends of sequences.
         """
         for mod in self.layers:
             x = mod(x, pos_emb, memory, attn_mask=attn_mask,
                     key_padding_mask=key_padding_mask)
         if self.norm is not None:
-            output = self.norm(output)
+            x = self.norm(x)
-        return output
+        return x

 class TransformerDecoderLayerRelPos(nn.Module):
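
The fix in this hunk replaces two leftover references to a variable named output (apparently from an earlier version of the forward method) with x, so the final LayerNorm and the return statement use the tensor the layer loop actually produces. A rough, generic sketch of the resulting control flow follows, assuming a stack built from copies of one layer; the real class takes more arguments (e.g. memory_key_padding_mask) that are omitted here.

import copy
import torch.nn as nn

class DecoderStackSketch(nn.Module):
    # Hypothetical stand-in for TransformerDecoderRelPos, reduced to loop + norm.
    def __init__(self, layer, num_layers, norm=None):
        super().__init__()
        self.layers = nn.ModuleList(copy.deepcopy(layer) for _ in range(num_layers))
        self.norm = norm

    def forward(self, x, pos_emb, memory, attn_mask=None, key_padding_mask=None):
        for mod in self.layers:
            x = mod(x, pos_emb, memory, attn_mask=attn_mask,
                    key_padding_mask=key_padding_mask)
        if self.norm is not None:  # the old code normalized an undefined `output` here
            x = self.norm(x)
        return x                   # the old code was `return output`
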


@@ -74,7 +74,6 @@ def test_masked_lm_conformer_encoder():
 def test_transformer_decoder_layer_rel_pos():
     # Also tests RelPositionalEncoding
     embed_dim = 256
     num_heads = 4
     T = 25
@@ -94,6 +93,26 @@ def test_transformer_decoder_layer_rel_pos():
+def test_transformer_decoder_rel_pos():
+    embed_dim = 256
+    num_heads = 4
+    T = 25
+    N = 4
+    C = 256
+    pos_emb_module = RelPositionalEncoding(C, dropout_rate=0.0)
+    decoder_layer = TransformerDecoderLayerRelPos(embed_dim, num_heads)
+    decoder_norm = torch.nn.LayerNorm(embed_dim)
+    decoder = TransformerDecoderRelPos(decoder_layer, num_layers=6, norm=decoder_norm)
+    x = torch.randn(N, T, C)
+    x, pos_emb = pos_emb_module(x)
+    x = x.transpose(0, 1)  # (T, N, C)
+    key_padding_mask = (torch.randn(N, T) > 0.0)  # (N, T)
+    attn_mask = generate_square_subsequent_mask(T)
+    memory = torch.randn(T, N, C)
+    y = decoder(x, pos_emb, memory, attn_mask=attn_mask, key_padding_mask=key_padding_mask)
+
+
 def test_transformer():
     return
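
The new test only exercises the forward pass. A possible extension, not part of this commit, would be to also check the output layout and that backprop runs; appended to test_transformer_decoder_rel_pos() it could look like:

    assert y.shape == (T, N, C)   # the decoder should preserve the (T, N, C) layout
    y.sum().backward()            # gradients should reach the decoder parameters
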