Mirror of https://github.com/k2-fsa/icefall.git, synced 2025-08-13 12:02:21 +00:00
Test, and fix, TransformerDecoderLayerRelPos

parent 556fae586f
commit 7856ab89fc
@@ -249,7 +249,7 @@ class TransformerDecoderRelPos(nn.Module):
         """
 
         for mod in self.layers:
-            x = mod(x, pos_emb, memory, x_mask=x_mask,
+            x = mod(x, pos_emb, memory, attn_mask=attn_mask,
                     key_padding_mask=key_padding_mask)
 
         if self.norm is not None:
@@ -294,7 +294,7 @@ class TransformerDecoderLayerRelPos(nn.Module):
         dropout: float = 0.1,
         activation: str = "relu",
     ) -> None:
-        super(TransformerDecoderLayer, self).__init__()
+        super(TransformerDecoderLayerRelPos, self).__init__()
         self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0)
         self.src_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0)
         # Implementation of Feedforward model
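Note on the super() fix above: the hunk headers show that TransformerDecoderLayerRelPos subclasses nn.Module directly, so passing the unrelated TransformerDecoderLayer class to super() fails as soon as the layer is constructed. A minimal sketch of that failure mode, using illustrative class names rather than the ones in conformer.py:

import torch.nn as nn

class Broken(nn.Module):
    def __init__(self):
        # Broken is not a subclass of nn.TransformerDecoderLayer, so this raises
        # "TypeError: super(type, obj): obj must be an instance or subtype of type"
        super(nn.TransformerDecoderLayer, self).__init__()

class Fixed(nn.Module):
    def __init__(self):
        super(Fixed, self).__init__()  # or simply: super().__init__()

Fixed()  # constructs fine
try:
    Broken()
except TypeError as e:
    print(e)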
@@ -315,14 +315,14 @@ class TransformerDecoderLayerRelPos(nn.Module):
     def __setstate__(self, state):
         if "activation" not in state:
             state["activation"] = nn.functional.relu
-        super(TransformerDecoderLayer, self).__setstate__(state)
+        super(TransformerDecoderLayerRelPos, self).__setstate__(state)
 
     def forward(
         self,
         x: torch.Tensor,
         pos_emb: torch.Tensor,
         memory: torch.Tensor,
-        x_mask: Optional[torch.Tensor] = None,
+        attn_mask: Optional[torch.Tensor] = None,
         key_padding_mask: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         """Pass the inputs (and mask) through the decoder layer.
@@ -330,13 +330,13 @@ class TransformerDecoderLayerRelPos(nn.Module):
         Args:
           x
             The input embedding, to be added to by the forward function, of shape (T, N, C).
-            Attention within x will be left-to-right only (causal), thanks to x_mask.
+            Attention within x will be left-to-right only (causal), thanks to attn_mask.
           pos_emb:
             A torch.Tensor with dtype=torch.float and shape (1, 2*T-1, C) with c==num_channels,
             containing the relative positional encoding.
           memory:
             the sequence from the last layer of the encoder (required). Shape = (T, N, C)
-          x_mask:
+          attn_mask:
             the mask for the x, to enforce causal (left to right) attention (optional).
             Shape == (T, T); may be bool or float. The first T pertains to the output,
             the second T to the input.
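The (T, T) attn_mask documented here is the usual square-subsequent (causal) mask; the new test below obtains one from generate_square_subsequent_mask. As a sketch of what such a mask looks like in either of the two forms the docstring allows (these helper names are illustrative and not part of the patch):

import torch

def causal_float_mask(T: int) -> torch.Tensor:
    # 0.0 on and below the diagonal (attention allowed), -inf strictly above it,
    # so output position t can only attend to input positions <= t.
    return torch.triu(torch.full((T, T), float("-inf")), diagonal=1)

def causal_bool_mask(T: int) -> torch.Tensor:
    # Boolean variant: True marks positions that must not be attended to.
    return torch.triu(torch.ones(T, T), diagonal=1).bool()

print(causal_float_mask(4))
print(causal_bool_mask(4))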
@@ -351,15 +351,17 @@ class TransformerDecoderLayerRelPos(nn.Module):
         residual = x
         x = self.norm1(x)
         self_attn = self.self_attn(x, x, x,
+                                   pos_emb=pos_emb,
                                    key_padding_mask=key_padding_mask,
                                    need_weights=False,
-                                   attn_mask=x_mask,
+                                   attn_mask=attn_mask,
         )[0]
         x = residual + self.dropout1(self_attn)
 
         residual = x
         x = self.norm2(x)
         src_attn = self.src_attn(x, memory, memory,
+                                 pos_emb=pos_emb,
                                  key_padding_mask=key_padding_mask,
                                  need_weights=False,
         )[0]
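Both attention calls now receive pos_emb. Its documented shape (1, 2*T-1, C) reflects that the relative offset between a key position and a query position in a length-T sequence takes 2*T-1 distinct values; a quick arithmetic check, independent of the patch:

T = 25
offsets = range(-(T - 1), T)      # possible key-position minus query-position offsets
assert len(offsets) == 2 * T - 1  # 49 distinct offsets for T == 25

The remaining hunks update the test file, which imports these classes from conformer.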
@@ -5,6 +5,7 @@
 import torch
 from conformer import (
     TransformerDecoderRelPos,
+    TransformerDecoderLayerRelPos,
     MaskedLmConformer,
     MaskedLmConformerEncoder,
     MaskedLmConformerEncoderLayer,
@@ -28,9 +29,9 @@ def test_rel_position_multihead_attention():
 
     x = torch.randn(N, T, C)
     #pos_emb = torch.randn(1, 2*T-1, C)
-    x, pos_enc = pos_emb_module(x)
+    x, pos_emb = pos_emb_module(x)
     x = x.transpose(0, 1)  # (T, N, C)
-    attn_output, attn_output_weights = rel_pos_multihead_attn(x, x, x, pos_enc)
+    attn_output, attn_output_weights = rel_pos_multihead_attn(x, x, x, pos_emb)
 
 
 def test_masked_lm_conformer_encoder_layer():
@@ -45,10 +46,10 @@ def test_masked_lm_conformer_encoder_layer():
 
 
     x = torch.randn(N, T, C)
-    x, pos_enc = pos_emb_module(x)
+    x, pos_emb = pos_emb_module(x)
     x = x.transpose(0, 1)  # (T, N, C)
     key_padding_mask = (torch.randn(N, T) > 0.0)  # (N, T)
-    y = encoder_layer(x, pos_enc, key_padding_mask=key_padding_mask)
+    y = encoder_layer(x, pos_emb, key_padding_mask=key_padding_mask)
 
 
 def test_masked_lm_conformer_encoder():
@@ -66,10 +67,31 @@ def test_masked_lm_conformer_encoder():
 
 
     x = torch.randn(N, T, C)
-    x, pos_enc = pos_emb_module(x)
+    x, pos_emb = pos_emb_module(x)
     x = x.transpose(0, 1)  # (T, N, C)
     key_padding_mask = (torch.randn(N, T) > 0.0)  # (N, T)
-    y = encoder(x, pos_enc, key_padding_mask=key_padding_mask)
+    y = encoder(x, pos_emb, key_padding_mask=key_padding_mask)
 
 
+def test_transformer_decoder_layer_rel_pos():
+    # Also tests RelPositionalEncoding
+    embed_dim = 256
+    num_heads = 4
+    T = 25
+    N = 4
+    C = 256
+    pos_emb_module = RelPositionalEncoding(C, dropout_rate=0.0)
+    decoder_layer = TransformerDecoderLayerRelPos(embed_dim, num_heads)
+
+
+    x = torch.randn(N, T, C)
+    x, pos_emb = pos_emb_module(x)
+    x = x.transpose(0, 1)  # (T, N, C)
+    key_padding_mask = (torch.randn(N, T) > 0.0)  # (N, T)
+    attn_mask = generate_square_subsequent_mask(T)
+    memory = torch.randn(T, N, C)
+    y = decoder_layer(x, pos_emb, memory, attn_mask=attn_mask, key_padding_mask=key_padding_mask)
+
+
+
 
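The new test exercises the decoder layer end to end. Its call pattern parallels PyTorch's stock torch.nn.TransformerDecoderLayer, except that the RelPos variant also threads pos_emb through both attention modules. For comparison only (not part of this commit), a minimal sketch with the stock layer under the same shapes:

import torch
import torch.nn as nn

T, N, C, num_heads = 25, 4, 256, 4
layer = nn.TransformerDecoderLayer(d_model=C, nhead=num_heads)  # batch_first=False: (T, N, C)

tgt = torch.randn(T, N, C)                                  # decoder input
memory = torch.randn(T, N, C)                               # encoder output
tgt_mask = torch.triu(torch.ones(T, T), diagonal=1).bool()  # causal (T, T) mask
tgt_key_padding_mask = torch.randn(N, T) > 0.0              # True marks padded positions

y = layer(tgt, memory,
          tgt_mask=tgt_mask,
          tgt_key_padding_mask=tgt_key_padding_mask)
assert y.shape == (T, N, C)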