Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-08-12 19:42:19 +00:00)
Test, and fix, TransformerDecoderLayerRelPos
parent 556fae586f
commit 7856ab89fc
@@ -249,7 +249,7 @@ class TransformerDecoderRelPos(nn.Module):
         """
 
         for mod in self.layers:
-            x = mod(x, pos_emb, memory, x_mask=x_mask,
+            x = mod(x, pos_emb, memory, attn_mask=attn_mask,
                     key_padding_mask=key_padding_mask)
 
         if self.norm is not None:
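For readers skimming the diff: this loop is the decoder stack's whole forward pass, i.e. thread the same pos_emb, memory and masks through every layer of an nn.ModuleList and optionally apply a final norm. A minimal runnable sketch of that pattern, using hypothetical stand-in classes (DummyLayer, DummyDecoderStack) rather than icefall's real ones:

import copy
import torch
import torch.nn as nn

class DummyLayer(nn.Module):
    # Stand-in with the same calling convention as TransformerDecoderLayerRelPos.
    def __init__(self, d_model: int):
        super().__init__()
        self.proj = nn.Linear(d_model, d_model)

    def forward(self, x, pos_emb, memory, attn_mask=None, key_padding_mask=None):
        return x + self.proj(x)

class DummyDecoderStack(nn.Module):
    # Clones one layer num_layers times and threads the shared arguments
    # through each clone, as in the hunk above.
    def __init__(self, layer: nn.Module, num_layers: int, norm: nn.Module = None):
        super().__init__()
        self.layers = nn.ModuleList([copy.deepcopy(layer) for _ in range(num_layers)])
        self.norm = norm

    def forward(self, x, pos_emb, memory, attn_mask=None, key_padding_mask=None):
        for mod in self.layers:
            x = mod(x, pos_emb, memory, attn_mask=attn_mask,
                    key_padding_mask=key_padding_mask)
        if self.norm is not None:
            x = self.norm(x)
        return x

T, N, C = 25, 4, 256
stack = DummyDecoderStack(DummyLayer(C), num_layers=6, norm=nn.LayerNorm(C))
out = stack(torch.randn(T, N, C), pos_emb=None, memory=torch.randn(T, N, C))
print(out.shape)  # torch.Size([25, 4, 256])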
@@ -294,7 +294,7 @@ class TransformerDecoderLayerRelPos(nn.Module):
         dropout: float = 0.1,
         activation: str = "relu",
     ) -> None:
-        super(TransformerDecoderLayer, self).__init__()
+        super(TransformerDecoderLayerRelPos, self).__init__()
         self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0)
         self.src_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0)
         # Implementation of Feedforward model
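The one-line super() fix above (and the matching fix in __setstate__ below) deserves a note: with the two-argument form, the class passed to super(cls, self) must appear in type(self)'s method resolution order, otherwise Python rejects self as soon as the object is constructed or unpickled. A small self-contained illustration with made-up class names:

class Base:
    def __init__(self):
        self.base_ready = True

class Unrelated:
    pass

class Broken(Base):
    def __init__(self):
        # Copy-paste bug analogous to the one fixed above: Unrelated is not
        # in Broken's MRO, so super() rejects self.
        super(Unrelated, self).__init__()

class Fixed(Base):
    def __init__(self):
        super(Fixed, self).__init__()  # names the defining class, as in the fix

try:
    Broken()
except TypeError as exc:
    print("broken:", exc)  # super(type, obj): obj must be an instance or subtype of type

print("fixed:", Fixed().base_ready)  # fixed: True

In Python 3, the zero-argument form super().__init__() sidesteps this copy-paste hazard entirely.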
@@ -315,14 +315,14 @@ class TransformerDecoderLayerRelPos(nn.Module):
     def __setstate__(self, state):
         if "activation" not in state:
             state["activation"] = nn.functional.relu
-        super(TransformerDecoderLayer, self).__setstate__(state)
+        super(TransformerDecoderLayerRelPos, self).__setstate__(state)
 
     def forward(
         self,
         x: torch.Tensor,
         pos_emb: torch.Tensor,
         memory: torch.Tensor,
-        x_mask: Optional[torch.Tensor] = None,
+        attn_mask: Optional[torch.Tensor] = None,
         key_padding_mask: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         """Pass the inputs (and mask) through the decoder layer.
@@ -330,13 +330,13 @@ class TransformerDecoderLayerRelPos(nn.Module):
         Args:
             x:
                 The input embedding, to be added to by the forward function, of shape (T, N, C).
-                Attention within x will be left-to-right only (causal), thanks to x_mask.
+                Attention within x will be left-to-right only (causal), thanks to attn_mask.
             pos_emb:
                 A torch.Tensor with dtype=torch.float and shape (1, 2*T-1, C) with C == num_channels,
                 containing the relative positional encoding.
             memory:
                 The sequence from the last layer of the encoder (required). Shape == (T, N, C).
-            x_mask:
+            attn_mask:
                 The mask for x, to enforce causal (left-to-right) attention (optional).
                 Shape == (T, T); may be bool or float. The first T pertains to the output,
                 the second T to the input.
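The (T, T) attn_mask documented here is the standard causal mask, and the new test below obtains one from generate_square_subsequent_mask(T). A generic sketch of such a helper, assuming the common float convention (0.0 where attention is allowed, -inf where it is not); it is not necessarily identical to the helper defined in conformer.py:

import torch

def causal_mask(T: int) -> torch.Tensor:
    # Shape (T, T): rows index the output (query) position, columns the input
    # (key) position, matching "the first T pertains to the output, the second
    # T to the input". Entries strictly above the diagonal are -inf.
    mask = torch.full((T, T), float("-inf"))
    return torch.triu(mask, diagonal=1)

print(causal_mask(4))
# tensor([[0., -inf, -inf, -inf],
#         [0., 0., -inf, -inf],
#         [0., 0., 0., -inf],
#         [0., 0., 0., 0.]])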
@@ -351,15 +351,17 @@ class TransformerDecoderLayerRelPos(nn.Module):
         residual = x
         x = self.norm1(x)
         self_attn = self.self_attn(x, x, x,
                                    pos_emb=pos_emb,
                                    key_padding_mask=key_padding_mask,
                                    need_weights=False,
-                                   attn_mask=x_mask,
+                                   attn_mask=attn_mask,
                                    )[0]
         x = residual + self.dropout1(self_attn)
 
         residual = x
         x = self.norm2(x)
         src_attn = self.src_attn(x, memory, memory,
                                  pos_emb=pos_emb,
                                  key_padding_mask=key_padding_mask,
                                  need_weights=False,
                                  )[0]
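The forward code above follows a pre-norm residual layout: norm1, causal self-attention over x, residual add; then norm2, cross-attention over memory, residual add; the feed-forward sub-block follows later in the file. For orientation only, here is the same control flow sketched with the stock torch.nn.MultiheadAttention, which takes no pos_emb argument; this is an illustration of the layout, not icefall's RelPositionMultiheadAttention-based implementation:

import torch
import torch.nn as nn

class PreNormDecoderBlockSketch(nn.Module):
    # Pre-norm self-attention + cross-attention, mirroring the residual
    # structure in the hunk above (feed-forward part omitted for brevity).
    def __init__(self, d_model: int, nhead: int, dropout: float = 0.1):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=0.0)
        self.src_attn = nn.MultiheadAttention(d_model, nhead, dropout=0.0)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, x, memory, attn_mask=None, key_padding_mask=None):
        residual = x
        x = self.norm1(x)
        self_attn = self.self_attn(x, x, x,
                                   key_padding_mask=key_padding_mask,
                                   need_weights=False,
                                   attn_mask=attn_mask)[0]
        x = residual + self.dropout1(self_attn)

        residual = x
        x = self.norm2(x)
        src_attn = self.src_attn(x, memory, memory, need_weights=False)[0]
        return residual + self.dropout2(src_attn)

T, N, C = 25, 4, 256
block = PreNormDecoderBlockSketch(C, nhead=4)
causal = torch.triu(torch.full((T, T), float("-inf")), diagonal=1)
y = block(torch.randn(T, N, C), torch.randn(T, N, C), attn_mask=causal)
print(y.shape)  # torch.Size([25, 4, 256])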
@@ -5,6 +5,7 @@
 import torch
 from conformer import (
     TransformerDecoderRelPos,
+    TransformerDecoderLayerRelPos,
     MaskedLmConformer,
     MaskedLmConformerEncoder,
     MaskedLmConformerEncoderLayer,
@@ -28,9 +29,9 @@ def test_rel_position_multihead_attention():
 
     x = torch.randn(N, T, C)
     #pos_emb = torch.randn(1, 2*T-1, C)
-    x, pos_enc = pos_emb_module(x)
+    x, pos_emb = pos_emb_module(x)
     x = x.transpose(0, 1)  # (T, N, C)
-    attn_output, attn_output_weights = rel_pos_multihead_attn(x, x, x, pos_enc)
+    attn_output, attn_output_weights = rel_pos_multihead_attn(x, x, x, pos_emb)
 
 
 def test_masked_lm_conformer_encoder_layer():
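In these tests, pos_emb_module is a RelPositionalEncoding instance: it returns the (possibly scaled and dropout-ed) input together with a relative positional encoding of shape (1, 2*T-1, C), one row per relative offset from T-1 down to -(T-1), as in Transformer-XL. A sketch of producing a sinusoidal tensor with that shape, assuming an even C; the real module in conformer.py additionally handles caching, scaling and dropout, which are omitted here:

import math
import torch

def rel_pos_embedding(T: int, C: int) -> torch.Tensor:
    # Sinusoidal embeddings for relative offsets T-1, T-2, ..., 0, ..., -(T-1),
    # stacked into a tensor of shape (1, 2*T-1, C). Assumes C is even.
    positions = torch.arange(T - 1, -T, -1.0).unsqueeze(1)       # (2*T-1, 1)
    div_term = torch.exp(torch.arange(0, C, 2).float() * -(math.log(10000.0) / C))
    pe = torch.zeros(2 * T - 1, C)
    pe[:, 0::2] = torch.sin(positions * div_term)
    pe[:, 1::2] = torch.cos(positions * div_term)
    return pe.unsqueeze(0)                                        # (1, 2*T-1, C)

print(rel_pos_embedding(T=25, C=256).shape)  # torch.Size([1, 49, 256])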
@@ -45,10 +46,10 @@ def test_masked_lm_conformer_encoder_layer():
 
 
     x = torch.randn(N, T, C)
-    x, pos_enc = pos_emb_module(x)
+    x, pos_emb = pos_emb_module(x)
     x = x.transpose(0, 1)  # (T, N, C)
     key_padding_mask = (torch.randn(N, T) > 0.0)  # (N, T)
-    y = encoder_layer(x, pos_enc, key_padding_mask=key_padding_mask)
+    y = encoder_layer(x, pos_emb, key_padding_mask=key_padding_mask)
 
 
 def test_masked_lm_conformer_encoder():
@@ -66,10 +67,31 @@ def test_masked_lm_conformer_encoder():
 
 
     x = torch.randn(N, T, C)
-    x, pos_enc = pos_emb_module(x)
+    x, pos_emb = pos_emb_module(x)
     x = x.transpose(0, 1)  # (T, N, C)
     key_padding_mask = (torch.randn(N, T) > 0.0)  # (N, T)
-    y = encoder(x, pos_enc, key_padding_mask=key_padding_mask)
+    y = encoder(x, pos_emb, key_padding_mask=key_padding_mask)
 
 
+def test_transformer_decoder_layer_rel_pos():
+    # Also tests RelPositionalEncoding
+    embed_dim = 256
+    num_heads = 4
+    T = 25
+    N = 4
+    C = 256
+    pos_emb_module = RelPositionalEncoding(C, dropout_rate=0.0)
+    decoder_layer = TransformerDecoderLayerRelPos(embed_dim, num_heads)
+
+
+    x = torch.randn(N, T, C)
+    x, pos_emb = pos_emb_module(x)
+    x = x.transpose(0, 1)  # (T, N, C)
+    key_padding_mask = (torch.randn(N, T) > 0.0)  # (N, T)
+    attn_mask = generate_square_subsequent_mask(T)
+    memory = torch.randn(T, N, C)
+    y = decoder_layer(x, pos_emb, memory, attn_mask=attn_mask, key_padding_mask=key_padding_mask)
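The new test pins down the calling convention end to end: x and memory of shape (T, N, C), a (T, T) causal attn_mask from generate_square_subsequent_mask, and an (N, T) boolean key_padding_mask. For comparison, the same shape bookkeeping with the stock torch.nn.TransformerDecoderLayer, which has no pos_emb argument and names its masks tgt_mask and tgt_key_padding_mask; this is an analogue for orientation, not the icefall test itself:

import torch
import torch.nn as nn

T, N, C, num_heads = 25, 4, 256, 4

decoder_layer = nn.TransformerDecoderLayer(d_model=C, nhead=num_heads)

x = torch.randn(T, N, C)       # decoder input, (T, N, C) with batch_first=False
memory = torch.randn(T, N, C)  # encoder output
key_padding_mask = torch.randn(N, T) > 0.0  # (N, T), True = position is ignored
attn_mask = torch.triu(torch.full((T, T), float("-inf")), diagonal=1)  # (T, T) causal

y = decoder_layer(x, memory,
                  tgt_mask=attn_mask,
                  tgt_key_padding_mask=key_padding_mask)
print(y.shape)  # torch.Size([25, 4, 256])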