mirror of https://github.com/k2-fsa/icefall.git
Test, and fix, TransformerDecoderRelPos
This commit is contained in:
parent 7856ab89fc
commit 5fecd24664
@@ -166,24 +166,24 @@ class MaskedLmConformer(nn.Module):
 
         tgt_mask = generate_square_subsequent_mask(T, memory.device)
 
-        src = self.embed(src_symbols) * self.embed_scale # (N, T) -> (N, T, C)
-        src = src.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
+        x = self.embed(src_symbols) * self.embed_scale # (N, T) -> (N, T, C)
+        x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
 
-        src = memory + self.src_linear(src) # (T, N, C)
+        x = memory + self.src_linear(x) # (T, N, C)
 
         # This is a little confusing, how "tgt" is set to src. "src" is the
         # symbol sequence without masking but with padding and randomization.
         # "tgt" is like "src" but shifted by one.
         pred = self.decoder(
-            tgt=src,
+            x,
+            pos_emb,
             memory=memory,
             tgt_mask=tgt_mask,
             tgt_key_padding_mask=key_padding_mask,
             memory_key_padding_mask=key_padding_mask,
         ) # (T, N, C)
 
-        pred = pred_pad.permute(1, 0, 2) # (T, N, C) -> (N, T, C)
+        pred = pred.permute(1, 0, 2) # (T, N, C) -> (N, T, C)
 
         pred = self.decoder_output_layer(pred) # (N, T, C)
 
         # nll: negative log-likelihood
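Note: both the model code above and the new test below rely on generate_square_subsequent_mask to build the causal tgt_mask/attn_mask. The commit does not include that helper, so the sketch below is only a minimal re-implementation under the usual convention (0.0 for allowed positions, -inf strictly above the diagonal); it is an assumption for illustration, not icefall's actual code.

import torch

def generate_square_subsequent_mask(sz: int, device: torch.device = torch.device("cpu")) -> torch.Tensor:
    # Sketch only: additive causal mask of shape (sz, sz).
    # Row i has 0.0 in columns j <= i (attendable) and -inf in columns j > i
    # (future positions), so it can be added to attention scores before softmax.
    mask = torch.full((sz, sz), float("-inf"), device=device)
    return torch.triu(mask, diagonal=1)

# Example: generate_square_subsequent_mask(3) ->
# tensor([[0., -inf, -inf],
#         [0.,   0., -inf],
#         [0.,   0.,   0.]])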
@@ -247,15 +247,14 @@ class TransformerDecoderRelPos(nn.Module):
             a torch.Tensor with dtype=bool and shape (N, T): true for masked
             positions after the ends of sequences.
         """
 
         for mod in self.layers:
             x = mod(x, pos_emb, memory, attn_mask=attn_mask,
                     key_padding_mask=key_padding_mask)
 
         if self.norm is not None:
-            output = self.norm(output)
+            x = self.norm(x)
 
-        return output
+        return x
 
 
 class TransformerDecoderLayerRelPos(nn.Module):
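The fix in TransformerDecoderRelPos.forward is a rename: the loop accumulates into x, but the old code normalized and returned an undefined variable output, which would raise NameError as soon as the decoder ran. The stand-in below reproduces the corrected control flow in isolation; TinyDecoderStack is a hypothetical name used only for this sketch and is not part of icefall.

import copy
import torch

class TinyDecoderStack(torch.nn.Module):
    # Hypothetical stand-in mirroring the fixed forward(): run every layer
    # on x, then apply the optional final norm to x and return x.
    def __init__(self, layer, num_layers, norm=None):
        super().__init__()
        self.layers = torch.nn.ModuleList(
            [copy.deepcopy(layer) for _ in range(num_layers)]
        )
        self.norm = norm

    def forward(self, x):
        for mod in self.layers:
            x = mod(x)
        if self.norm is not None:
            x = self.norm(x)
        return x

# Smoke test with plain Linear layers standing in for the decoder layers:
stack = TinyDecoderStack(torch.nn.Linear(256, 256), num_layers=6,
                         norm=torch.nn.LayerNorm(256))
y = stack(torch.randn(25, 4, 256))  # (T, N, C) layout, as in the real decoder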
@@ -74,7 +74,6 @@ def test_masked_lm_conformer_encoder():
 
-
 def test_transformer_decoder_layer_rel_pos():
     # Also tests RelPositionalEncoding
     embed_dim = 256
     num_heads = 4
     T = 25
@@ -94,6 +93,26 @@ def test_transformer_decoder_layer_rel_pos():
 
 
+def test_transformer_decoder_rel_pos():
+    embed_dim = 256
+    num_heads = 4
+    T = 25
+    N = 4
+    C = 256
+    pos_emb_module = RelPositionalEncoding(C, dropout_rate=0.0)
+    decoder_layer = TransformerDecoderLayerRelPos(embed_dim, num_heads)
+    decoder_norm = torch.nn.LayerNorm(embed_dim)
+    decoder = TransformerDecoderRelPos(decoder_layer, num_layers=6, norm=decoder_norm)
+
+    x = torch.randn(N, T, C)
+    x, pos_emb = pos_emb_module(x)
+    x = x.transpose(0, 1)  # (T, N, C)
+    key_padding_mask = (torch.randn(N, T) > 0.0)  # (N, T)
+    attn_mask = generate_square_subsequent_mask(T)
+    memory = torch.randn(T, N, C)
+    y = decoder(x, pos_emb, memory, attn_mask=attn_mask, key_padding_mask=key_padding_mask)
+
+
 def test_transformer():
     return
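The new test builds key_padding_mask from random noise, which is enough for a shape/smoke check. For reference, the docstring convention above ((N, T) bool, True past each sequence's end) is usually derived from sequence lengths; the helper below shows that construction. lengths_to_key_padding_mask is a hypothetical name introduced here for illustration and does not appear in the diff.

import torch

def lengths_to_key_padding_mask(lengths: torch.Tensor, max_len: int = None) -> torch.Tensor:
    # (N,) integer lengths -> (N, T) bool mask, True at padded positions.
    if max_len is None:
        max_len = int(lengths.max())
    positions = torch.arange(max_len, device=lengths.device)  # (T,)
    return positions.unsqueeze(0) >= lengths.unsqueeze(1)     # (N, T)

# Example: lengths 5, 3, 1 padded to T=5 ->
# tensor([[False, False, False, False, False],
#         [False, False, False,  True,  True],
#         [False,  True,  True,  True,  True]])
print(lengths_to_key_padding_mask(torch.tensor([5, 3, 1])))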