Get tests to work for MaskedLmConformer

Daniel Povey 2021-08-23 19:05:31 +08:00
parent 5fecd24664
commit 26b5b5ba46
3 changed files with 62 additions and 52 deletions

View File

@ -52,8 +52,8 @@ class MaskedLmConformer(nn.Module):
# self.embed is the embedding used for both the encoder and decoder.
self.embed_scale = d_model ** 0.5
self.embed = nn.Embedding(
num_embeddings=self.decoder_num_class, embedding_dim=d_model,
_weight=torch.randn(self.decoder_num_class, d_model) * (1 / self.embed_scale)
num_embeddings=self.num_classes, embedding_dim=d_model,
_weight=torch.randn(self.num_classes, d_model) * (1 / self.embed_scale)
)
self.encoder_pos = RelPositionalEncoding(d_model, dropout)
@ -69,9 +69,8 @@ class MaskedLmConformer(nn.Module):
norm=nn.LayerNorm(d_model))
if num_decoder_layers > 0:
self.decoder_num_class = self.num_classes
decoder_layer = TransformerDecoderLayerRelPos(
decoder_layer = RelPosTransformerDecoderLayer(
d_model=d_model,
nhead=nhead,
dim_feedforward=dim_feedforward,
@ -82,14 +81,14 @@ class MaskedLmConformer(nn.Module):
self.src_linear = torch.nn.Linear(d_model, d_model)
decoder_norm = nn.LayerNorm(d_model)
self.decoder = TransformerDecoderRelPos(
self.decoder = RelPosTransformerDecoder(
decoder_layer=decoder_layer,
num_layers=num_decoder_layers,
norm=decoder_norm,
)
self.decoder_output_layer = torch.nn.Linear(
d_model, self.decoder_num_class
d_model, self.num_classes
)
@ -112,8 +111,8 @@ class MaskedLmConformer(nn.Module):
Returns:
Returns (encoded, pos_emb), where:
`encoded` is a Tensor containing the encoded data; it is of shape (N, T, C)
Returns (memory, pos_emb), where:
`memory` is a Tensor containing the encoded data; it is of shape (N, T, C)
where C is the embedding_dim.
`pos_emb` is a Tensor containing the relative positional encoding, of
shape (1, 2*T-1, C)
@ -164,7 +163,7 @@ class MaskedLmConformer(nn.Module):
"""
(T, N, C) = memory.shape
tgt_mask = generate_square_subsequent_mask(T, memory.device)
attn_mask = generate_square_subsequent_mask(T, memory.device)
x = self.embed(src_symbols) * self.embed_scale # (N, T) -> (N, T, C)
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
@ -178,18 +177,17 @@ class MaskedLmConformer(nn.Module):
x,
pos_emb,
memory=memory,
tgt_mask=tgt_mask,
tgt_key_padding_mask=key_padding_mask,
memory_key_padding_mask=key_padding_mask,
) # (T, N, C)
attn_mask=attn_mask,
key_padding_mask=key_padding_mask)
# (T, N, C)
pred = pred.permute(1, 0, 2) # (T, N, C) -> (N, T, C)
pred = self.decoder_output_layer(pred) # (N, T, C)
# nll: negative log-likelihood
nll = torch.nn.functional.cross_entropy(
pred.view(-1, self.decoder_num_class),
tgt_symbols.view(-1),
pred.view(-1, self.num_classes),
tgt_symbols.reshape(-1),
reduction="none",
)
nll = nll.view(N, T)
@ -198,19 +196,19 @@ class MaskedLmConformer(nn.Module):
class TransformerDecoderRelPos(nn.Module):
r"""TransformerDecoderRelPos is a stack of N decoder layers.
class RelPosTransformerDecoder(nn.Module):
r"""RelPosTransformerDecoder is a stack of N decoder layers.
This is modified from nn.TransformerDecoder to support relative positional
encoding.
Args:
decoder_layer: an instance of the TransformerDecoderLayerRelPos() class (required).
decoder_layer: an instance of the RelPosTransformerDecoderLayer() class (required).
num_layers: the number of sub-decoder-layers in the decoder (required).
norm: the layer normalization component (optional).
Examples::
>>> decoder_layer = nn.TransformerDecoderLayerRelPos(d_model=512, nhead=8)
>>> transformer_decoder = nn.TransformerDecoderRelPos(decoder_layer, num_layers=6)
>>> decoder_layer = RelPosTransformerDecoderLayer(d_model=512, nhead=8)
>>> transformer_decoder = RelPosTransformerDecoder(decoder_layer, num_layers=6)
>>> memory = torch.rand(10, 32, 512)
>>> tgt = torch.rand(20, 32, 512)
>>> pos_enc = torch.rand(1, 2*20-1, 512)
@ -219,7 +217,7 @@ class TransformerDecoderRelPos(nn.Module):
__constants__ = ['norm']
def __init__(self, decoder_layer, num_layers, norm=None):
super(TransformerDecoderRelPos, self).__init__()
super(RelPosTransformerDecoder, self).__init__()
self.layers = _get_clones(decoder_layer, num_layers)
self.num_layers = num_layers
self.norm = norm
@ -257,7 +255,7 @@ class TransformerDecoderRelPos(nn.Module):
return x
class TransformerDecoderLayerRelPos(nn.Module):
class RelPosTransformerDecoderLayer(nn.Module):
"""
Modified from torch.nn.TransformerDecoderLayer.
It adds normalize_before (hardcoded to True), i.e. layer_norm is applied before the first block;
@ -278,7 +276,7 @@ class TransformerDecoderLayerRelPos(nn.Module):
gelu (default=relu).
Examples::
>>> decoder_layer = nn.TransformerDecoderLayerRelPos(d_model=512, nhead=8)
>>> decoder_layer = RelPosTransformerDecoderLayer(d_model=512, nhead=8)
>>> memory = torch.rand(10, 32, 512)
>>> tgt = torch.rand(20, 32, 512)
>>> pos_emb = torch.rand(1, 2*20-1, 512)
@ -293,7 +291,7 @@ class TransformerDecoderLayerRelPos(nn.Module):
dropout: float = 0.1,
activation: str = "relu",
) -> None:
super(TransformerDecoderLayerRelPos, self).__init__()
super(RelPosTransformerDecoderLayer, self).__init__()
self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0)
self.src_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0)
# Implementation of Feedforward model
@ -314,7 +312,7 @@ class TransformerDecoderLayerRelPos(nn.Module):
def __setstate__(self, state):
if "activation" not in state:
state["activation"] = nn.functional.relu
super(TransformerDecoderLayerRelPos, self).__setstate__(state)
super(RelPosTransformerDecoderLayer, self).__setstate__(state)
def forward(
self,

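For orientation, here is a hedged sketch of the renamed decoder API from the hunks above, mirroring the call pattern of the updated test_transformer_decoder_rel_pos test further down. The sizes (T, N, C, nhead) and the all-False key_padding_mask are illustrative assumptions, and generate_square_subsequent_mask is assumed to be importable from conformer alongside the other names.

import torch
from conformer import (RelPositionalEncoding, RelPosTransformerDecoderLayer,
                       RelPosTransformerDecoder, generate_square_subsequent_mask)

T, N, C, nhead = 20, 4, 256, 4                          # illustrative sizes (assumed)
pos_emb_module = RelPositionalEncoding(C, dropout_rate=0.0)
decoder_layer = RelPosTransformerDecoderLayer(C, nhead)
decoder = RelPosTransformerDecoder(decoder_layer, num_layers=6,
                                   norm=torch.nn.LayerNorm(C))

x = torch.randn(N, T, C)
x, pos_emb = pos_emb_module(x)                          # pos_emb: (1, 2*T-1, C)
x = x.permute(1, 0, 2)                                  # (N, T, C) -> (T, N, C)
memory = torch.randn(T, N, C)
attn_mask = generate_square_subsequent_mask(T, x.device)
key_padding_mask = torch.zeros(N, T, dtype=torch.bool)  # no padding (assumed)

# Before this commit the call used tgt_mask=, tgt_key_padding_mask= and
# memory_key_padding_mask=; the renamed classes take a single attn_mask and
# a single key_padding_mask, as in decoder_nll above.
y = decoder(x, pos_emb, memory,
            attn_mask=attn_mask,
            key_padding_mask=key_padding_mask)          # (T, N, C)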
View File

@ -297,7 +297,7 @@ def collate_fn(sentences: List[List[int]],
Will be reflected in the returned tgt_weights tensor.
Returns a tuple (masked_src_symbols, src_symbols,
tgt_symbols, src_attn_mask,
tgt_symbols, src_key_padding_mask,
tgt_weights),
all with 2 axes and the same shape: (num_sent, seq_len).
Their dtypes will be, respectively,
@ -315,7 +315,7 @@ def collate_fn(sentences: List[List[int]],
tgt_symbols: The original sentences, with eos_symbol appended, and then
padded with blank to the same length as masked_symbols and
src_symbols.
src_attn_mask: Masking tensor for masked_src_symbols and src_symbols, to
src_key_padding_mask: Masking tensor for masked_src_symbols and src_symbols, to
account for all the sentence lengths not being identical
(makes each sentence's processing independent of seq_len).
Tensor of Bool of shape (num_sent, seq_len), with True
@ -368,17 +368,17 @@ def collate_fn(sentences: List[List[int]],
src_symbols = torch.tensor(srcs, dtype=torch.int64)
masked_src_symbols = torch.tensor(srcs_masked, dtype=torch.int64)
tgt_symbols = torch.tensor(tgts, dtype=torch.int64)
src_attn_mask = torch.tensor(attn_masks, dtype=torch.bool)
src_key_padding_mask = torch.tensor(attn_masks, dtype=torch.bool)
tgt_weights = torch.tensor(weights, dtype=torch.float)
attn_mask_sum = torch.sum(torch.logical_not(src_attn_mask), dim=0).tolist()
attn_mask_sum = torch.sum(torch.logical_not(src_key_padding_mask), dim=0).tolist()
while attn_mask_sum[-1] == 0: # Remove always-masked positions at the end of the lists.
attn_mask_sum.pop()
if len(attn_mask_sum) < seq_len:
seq_len = len(attn_mask_sum)
(src_symbols, masked_src_symbols,
tgt_symbols, src_attn_mask, tgt_weights) = (src_symbols[:,:seq_len], masked_src_symbols[:,:seq_len],
tgt_symbols[:,:seq_len], src_attn_mask[:,:seq_len],
tgt_symbols, src_key_padding_mask, tgt_weights) = (src_symbols[:,:seq_len], masked_src_symbols[:,:seq_len],
tgt_symbols[:,:seq_len], src_key_padding_mask[:,:seq_len],
tgt_weights[:,:seq_len])
if randomize_proportion > 0.0:
@ -409,9 +409,9 @@ def collate_fn(sentences: List[List[int]],
check_collated_tensors(sentences, bos_sym, eos_sym, blank_sym,
unmasked_weight,
masked_src_symbols, src_symbols,
tgt_symbols, src_attn_mask, tgt_weights)
tgt_symbols, src_key_padding_mask, tgt_weights)
return (masked_src_symbols, src_symbols,
tgt_symbols, src_attn_mask, tgt_weights)
tgt_symbols, src_key_padding_mask, tgt_weights)
@ -421,20 +421,20 @@ def check_collated_tensors(sentences: List[List[int]],
blank_sym: int,
unmasked_weight: float,
masked_src_symbols, src_symbols,
tgt_symbols, src_attn_mask,
tgt_symbols, src_key_padding_mask,
tgt_weights):
"""
This function checks the output of collate_fn; consider it test code. Please see
the documentation of collate_fn to understand the args.
"""
for t in src_symbols, tgt_symbols, src_attn_mask, tgt_weights:
for t in src_symbols, tgt_symbols, src_key_padding_mask, tgt_weights:
assert t.shape == masked_src_symbols.shape
tot_positions = src_symbols.numel()
masked_src_symbols, src_symbols, tgt_symbols, src_attn_mask, tgt_weights = (
masked_src_symbols, src_symbols, tgt_symbols, src_key_padding_mask, tgt_weights = (
masked_src_symbols.tolist(), src_symbols.tolist(), tgt_symbols.tolist(),
src_attn_mask.tolist(), tgt_weights.tolist())
src_key_padding_mask.tolist(), tgt_weights.tolist())
assert len(sentences) == len(masked_src_symbols)
tot_masked_positions = 0
@ -451,7 +451,7 @@ def check_collated_tensors(sentences: List[List[int]],
if sentences[i] != reconstructed_sent:
print(f"Error: sentence {i}={sentences[i]} differs from {reconstructed_sent}")
(masked_src, src, tgt, src_mask, weights) = (masked_src_symbols[i], src_symbols[i],
tgt_symbols[i], src_attn_mask[i], tgt_weights[i])
tgt_symbols[i], src_key_padding_mask[i], tgt_weights[i])
assert src[0] == masked_src[0] == bos_sym
for j in range(len(masked_src)):

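To make the renamed collate_fn outputs concrete, here is a small usage sketch mirroring the call added to test_conformer.py below; the sentences and the bos/eos/blank symbol ids are arbitrary example values.

import torch
import dataset

(masked_src_symbols, src_symbols,
 tgt_symbols, src_key_padding_mask,
 tgt_weights) = dataset.collate_fn(
    sentences=[list(range(10, 20)), list(range(30, 45)), list(range(50, 68))],
    bos_sym=1, eos_sym=2, blank_sym=0)

# All five tensors share one shape, (num_sent, seq_len). src_key_padding_mask is
# the renamed Bool mask (True at padded positions); the symbol tensors are int64
# and tgt_weights is float, as constructed above.
assert src_key_padding_mask.dtype == torch.bool
assert src_key_padding_mask.shape == masked_src_symbols.shape
assert tgt_weights.dtype == torch.float32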
View File

@ -3,9 +3,10 @@
# python3 -m pytest test_conformer.py
import torch
import dataset # from .
from conformer import (
TransformerDecoderRelPos,
TransformerDecoderLayerRelPos,
RelPosTransformerDecoder,
RelPosTransformerDecoderLayer,
MaskedLmConformer,
MaskedLmConformerEncoder,
MaskedLmConformerEncoderLayer,
@ -80,7 +81,7 @@ def test_transformer_decoder_layer_rel_pos():
N = 4
C = 256
pos_emb_module = RelPositionalEncoding(C, dropout_rate=0.0)
decoder_layer = TransformerDecoderLayerRelPos(embed_dim, num_heads)
decoder_layer = RelPosTransformerDecoderLayer(embed_dim, num_heads)
x = torch.randn(N, T, C)
@ -100,10 +101,9 @@ def test_transformer_decoder_rel_pos():
N = 4
C = 256
pos_emb_module = RelPositionalEncoding(C, dropout_rate=0.0)
decoder_layer = TransformerDecoderLayerRelPos(embed_dim, num_heads)
decoder_layer = RelPosTransformerDecoderLayer(embed_dim, num_heads)
decoder_norm = torch.nn.LayerNorm(embed_dim)
decoder = TransformerDecoderRelPos(decoder_layer, num_layers=6, norm=decoder_norm)
decoder = RelPosTransformerDecoder(decoder_layer, num_layers=6, norm=decoder_norm)
x = torch.randn(N, T, C)
x, pos_emb = pos_emb_module(x)
@ -114,18 +114,30 @@ def test_transformer_decoder_rel_pos():
y = decoder(x, pos_emb, memory, attn_mask=attn_mask, key_padding_mask=key_padding_mask)
def test_transformer():
return
num_features = 40
def test_masked_lm_conformer():
num_classes = 87
model = Transformer(num_features=num_features, num_classes=num_classes)
d_model = 256
model = MaskedLmConformer(num_classes, d_model)
N = 31
for T in range(7, 30):
x = torch.rand(N, T, num_features)
y, _, _ = model(x)
assert y.shape == (N, (((T - 1) // 2) - 1) // 2, num_classes)
(masked_src_symbols, src_symbols,
tgt_symbols, src_key_padding_mask,
tgt_weights) = dataset.collate_fn(sentences=[list(range(10, 20)), list(range(30, 45)), list(range(50, 68))], bos_sym=1, eos_sym=2,
blank_sym=0)
# test forward() of MaskedLmConformer
memory, pos_emb = model(masked_src_symbols, src_key_padding_mask)
nll = model.decoder_nll(memory, pos_emb, src_symbols, tgt_symbols,
src_key_padding_mask)
print("nll = ", nll)
loss = (nll * tgt_weights).sum()
print("loss = ", loss)
def test_generate_square_subsequent_mask():
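The attn_mask built in decoder_nll above comes from generate_square_subsequent_mask(T, device). Assuming it follows the usual PyTorch nn.Transformer recipe (0.0 where attention is allowed, -inf above the diagonal), a size-3 mask looks like the sketch below; this illustrates the convention and is not a copy of the repository's implementation.

import torch
from conformer import generate_square_subsequent_mask  # assumed import location

mask = generate_square_subsequent_mask(3, torch.device("cpu"))
print(mask)
# Expected under the standard causal convention (assumed):
# tensor([[0., -inf, -inf],
#         [0.,   0., -inf],
#         [0.,   0.,   0.]])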