mirror of https://github.com/k2-fsa/icefall.git (synced 2025-08-15 20:22:42 +00:00)
Get tests to work for MaskedLmConformer
This commit is contained in:
parent 5fecd24664
commit 26b5b5ba46
@@ -52,8 +52,8 @@ class MaskedLmConformer(nn.Module):
         # self.embed is the embedding used for both the encoder and decoder.
         self.embed_scale = d_model ** 0.5
         self.embed = nn.Embedding(
-            num_embeddings=self.decoder_num_class, embedding_dim=d_model,
-            _weight=torch.randn(self.decoder_num_class, d_model) * (1 / self.embed_scale)
+            num_embeddings=self.num_classes, embedding_dim=d_model,
+            _weight=torch.randn(self.num_classes, d_model) * (1 / self.embed_scale)
         )

         self.encoder_pos = RelPositionalEncoding(d_model, dropout)
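
For reference, the hunk above relies on a scaling convention that is easy to miss: the embedding table is initialized with standard deviation about 1/sqrt(d_model), and each lookup is later multiplied back by embed_scale = sqrt(d_model) (see the self.embed(src_symbols) * self.embed_scale line further down in this diff). A minimal standalone sketch of that convention in plain PyTorch (not part of the commit; the sizes are illustrative):

    import torch
    import torch.nn as nn

    d_model, num_classes = 256, 87
    embed_scale = d_model ** 0.5
    # Parameters are kept small (std ~ 1/sqrt(d_model)); the scale is restored at lookup time.
    embed = nn.Embedding(
        num_embeddings=num_classes, embedding_dim=d_model,
        _weight=torch.randn(num_classes, d_model) * (1 / embed_scale),
    )
    tokens = torch.randint(0, num_classes, (4, 10))  # (N, T)
    x = embed(tokens) * embed_scale                  # (N, T, C), roughly unit variance
    print(x.std())                                   # close to 1.0
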
@@ -69,9 +69,8 @@ class MaskedLmConformer(nn.Module):
                 norm=nn.LayerNorm(d_model))

         if num_decoder_layers > 0:
-            self.decoder_num_class = self.num_classes

-            decoder_layer = TransformerDecoderLayerRelPos(
+            decoder_layer = RelPosTransformerDecoderLayer(
                 d_model=d_model,
                 nhead=nhead,
                 dim_feedforward=dim_feedforward,
@@ -82,14 +81,14 @@ class MaskedLmConformer(nn.Module):
             self.src_linear = torch.nn.Linear(d_model, d_model)

             decoder_norm = nn.LayerNorm(d_model)
-            self.decoder = TransformerDecoderRelPos(
+            self.decoder = RelPosTransformerDecoder(
                 decoder_layer=decoder_layer,
                 num_layers=num_decoder_layers,
                 norm=decoder_norm,
             )

             self.decoder_output_layer = torch.nn.Linear(
-                d_model, self.decoder_num_class
+                d_model, self.num_classes
             )


@@ -112,8 +111,8 @@ class MaskedLmConformer(nn.Module):


        Returns:
-          Returns (encoded, pos_emb), where:
-            `encoded` is a Tensor containing the encoded data; it is of shape (N, T, C)
+          Returns (memory, pos_emb), where:
+            `memory` is a Tensor containing the encoded data; it is of shape (N, T, C)
               where C is the embedding_dim.
            `pos_emb` is a Tensor containing the relative positional encoding, of
               shape (1, 2*T-1, C)
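
The docstring above states that pos_emb has shape (1, 2*T-1, C). As a quick reminder of where the 2*T-1 comes from (an editorial sketch, not part of the commit): with T frames, the relative offset key_position - query_position can take any value in [-(T-1), T-1], which is 2*T-1 distinct offsets, and each offset gets its own C-dimensional vector.

    T = 5
    offsets = list(range(-(T - 1), T))  # [-4, -3, ..., 3, 4]
    assert len(offsets) == 2 * T - 1
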
@@ -164,7 +163,7 @@ class MaskedLmConformer(nn.Module):
        """
        (T, N, C) = memory.shape

-        tgt_mask = generate_square_subsequent_mask(T, memory.device)
+        attn_mask = generate_square_subsequent_mask(T, memory.device)

        x = self.embed(src_symbols) * self.embed_scale # (N, T) -> (N, T, C)
        x = x.permute(1, 0, 2)  # (N, T, C) -> (T, N, C)
@@ -178,18 +177,17 @@ class MaskedLmConformer(nn.Module):
            x,
            pos_emb,
            memory=memory,
-            tgt_mask=tgt_mask,
-            tgt_key_padding_mask=key_padding_mask,
-            memory_key_padding_mask=key_padding_mask,
-        ) # (T, N, C)
+            attn_mask=attn_mask,
+            key_padding_mask=key_padding_mask)
+        # (T, N, C)

        pred = pred.permute(1, 0, 2) # (T, N, C) -> (N, T, C)
        pred = self.decoder_output_layer(pred)  # (N, T, C)

        # nll: negative log-likelihood
        nll = torch.nn.functional.cross_entropy(
-            pred.view(-1, self.decoder_num_class),
-            tgt_symbols.view(-1),
+            pred.view(-1, self.num_classes),
+            tgt_symbols.reshape(-1),
            reduction="none",
        )
        nll = nll.view(N, T)
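
decoder_nll() keeps the cross-entropy un-reduced (reduction="none") and reshapes it to (N, T) so that the caller can weight each position separately; the new test further down combines it with the tgt_weights returned by collate_fn. A standalone sketch of that pattern in plain PyTorch (made-up shapes and weights; a weight of 0 simply drops a position from the loss):

    import torch

    N, T, num_classes = 2, 5, 10
    logits = torch.randn(N, T, num_classes)          # stands in for the decoder output
    targets = torch.randint(0, num_classes, (N, T))  # stands in for tgt_symbols
    nll = torch.nn.functional.cross_entropy(
        logits.view(-1, num_classes), targets.reshape(-1), reduction="none"
    ).view(N, T)
    tgt_weights = torch.tensor([[1., 1., 1., 0., 0.],
                                [1., 1., 1., 1., 0.]])
    loss = (nll * tgt_weights).sum()
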
@@ -198,19 +196,19 @@ class TransformerDecoderRelPos(nn.Module):



-class TransformerDecoderRelPos(nn.Module):
-    r"""TransformerDecoderRelPos is a stack of N decoder layers.
+class RelPosTransformerDecoder(nn.Module):
+    r"""RelPosTransformerDecoder is a stack of N decoder layers.
     This is modified from nn.TransformerDecoder to support relative positional
     encoding.

     Args:
-        decoder_layer: an instance of the TransformerDecoderLayerRelPos() class (required).
+        decoder_layer: an instance of the RelPosTransformerDecoderLayer() class (required).
         num_layers: the number of sub-decoder-layers in the decoder (required).
         norm: the layer normalization component (optional).

     Examples::
-        >>> decoder_layer = nn.TransformerDecoderLayerRelPos(d_model=512, nhead=8)
-        >>> transformer_decoder = nn.TransformerDecoderRelPos(decoder_layer, num_layers=6)
+        >>> decoder_layer = nn.RelPosTransformerDecoderLayer(d_model=512, nhead=8)
+        >>> transformer_decoder = nn.RelPosTransformerDecoder(decoder_layer, num_layers=6)
         >>> memory = torch.rand(10, 32, 512)
         >>> tgt = torch.rand(20, 32, 512)
         >>> pos_enc = torch.rand()
@@ -219,7 +217,7 @@ class TransformerDecoderRelPos(nn.Module):
     __constants__ = ['norm']

     def __init__(self, decoder_layer, num_layers, norm=None):
-        super(TransformerDecoderRelPos, self).__init__()
+        super(RelPosTransformerDecoder, self).__init__()
         self.layers = _get_clones(decoder_layer, num_layers)
         self.num_layers = num_layers
         self.norm = norm
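
RelPosTransformerDecoder keeps the stacked-layer pattern of nn.TransformerDecoder, which it is modified from: _get_clones copies the single configured layer num_layers times. For readers unfamiliar with that helper, it is presumably equivalent to the torch.nn version sketched below (an assumption, not a quote of the icefall code):

    import copy
    import torch.nn as nn

    def _get_clones(module, N):
        # N independent deep copies of the same layer, registered as sub-modules
        return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])
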
@@ -257,7 +255,7 @@ class TransformerDecoderRelPos(nn.Module):
         return x


-class TransformerDecoderLayerRelPos(nn.Module):
+class RelPosTransformerDecoderLayer(nn.Module):
     """
     Modified from torch.nn.TransformerDecoderLayer.
     Add it to use normalize_before (hardcoded to True), i.e. use layer_norm before the first block;
@@ -278,7 +276,7 @@ class TransformerDecoderLayerRelPos(nn.Module):
            gelu (default=relu).

     Examples::
-        >>> decoder_layer = nn.TransformerDecoderLayerRelPos(d_model=512, nhead=8)
+        >>> decoder_layer = nn.RelPosTransformerDecoderLayer(d_model=512, nhead=8)
         >>> memory = torch.rand(10, 32, 512)
         >>> tgt = torch.rand(20, 32, 512)
         >>> pos_emb = torch.rand(1, 20*2+1, 512)
@@ -293,7 +291,7 @@ class TransformerDecoderLayerRelPos(nn.Module):
         dropout: float = 0.1,
         activation: str = "relu",
     ) -> None:
-        super(TransformerDecoderLayerRelPos, self).__init__()
+        super(RelPosTransformerDecoderLayer, self).__init__()
         self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0)
         self.src_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0)
         # Implementation of Feedforward model
@@ -314,7 +312,7 @@ class TransformerDecoderLayerRelPos(nn.Module):
     def __setstate__(self, state):
         if "activation" not in state:
             state["activation"] = nn.functional.relu
-        super(TransformerDecoderLayerRelPos, self).__setstate__(state)
+        super(RelPosTransformerDecoderLayer, self).__setstate__(state)

     def forward(
         self,
@@ -297,7 +297,7 @@ def collate_fn(sentences: List[List[int]],
          Will be reflected in the returned tgt_weights tensor.

     Returns a tuple (masked_src_symbols, src_symbols,
-                     tgt_symbols, src_attn_mask,
+                     tgt_symbols, src_key_padding_mask,
                      tgt_weights),
     all with 2 axes and the same shape: (num_sent, seq_len).
     Their dtypes will be, respectively,
@@ -315,7 +315,7 @@ def collate_fn(sentences: List[List[int]],
      tgt_symbols:  The original sentences, with eos_symbol appended, and then
                    padded with blank to the same length as masked_symbols and
                    src_symbols.
-     src_attn_mask:  Masking tensor for masked_src_symbols and src_symbols, to
+     src_key_padding_mask:  Masking tensor for masked_src_symbols and src_symbols, to
                    account for all the sentence lengths not being identical
                    (makes each sentence's processing independent of seq_len).
                    Tensor of Bool of shape (num_sent, seq_len), with True
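
src_key_padding_mask is a Bool tensor of shape (num_sent, seq_len); the sentence above is cut off in this hunk, but the mask presumably has True at the padded positions so that each sentence's processing is independent of seq_len. An editorial sketch of building such a mask from sentence lengths in plain PyTorch (collate_fn itself derives it from the padded lists, so this is only an illustration):

    import torch

    lengths = torch.tensor([4, 2, 5])               # hypothetical per-sentence lengths
    seq_len = int(lengths.max())
    positions = torch.arange(seq_len).unsqueeze(0)  # (1, seq_len)
    src_key_padding_mask = positions >= lengths.unsqueeze(1)  # (num_sent, seq_len), True = padding
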
@@ -368,17 +368,17 @@ def collate_fn(sentences: List[List[int]],
     src_symbols = torch.tensor(srcs, dtype=torch.int64)
     masked_src_symbols = torch.tensor(srcs_masked, dtype=torch.int64)
     tgt_symbols = torch.tensor(tgts, dtype=torch.int64)
-    src_attn_mask = torch.tensor(attn_masks, dtype=torch.bool)
+    src_key_padding_mask = torch.tensor(attn_masks, dtype=torch.bool)
     tgt_weights = torch.tensor(weights, dtype=torch.float)

-    attn_mask_sum = torch.sum(torch.logical_not(src_attn_mask), dim=0).tolist()
+    attn_mask_sum = torch.sum(torch.logical_not(src_key_padding_mask), dim=0).tolist()
     while attn_mask_sum[-1] == 0:  # Remove always-masked positions at the endof the lists.
         attn_mask_sum.pop()
     if len(attn_mask_sum) < seq_len:
         seq_len = len(attn_mask_sum)
         (src_symbols, masked_src_symbols,
-         tgt_symbols, src_attn_mask, tgt_weights) = (src_symbols[:,:seq_len], masked_src_symbols[:,:seq_len],
-                                                     tgt_symbols[:,:seq_len], src_attn_mask[:,:seq_len],
+         tgt_symbols, src_key_padding_mask, tgt_weights) = (src_symbols[:,:seq_len], masked_src_symbols[:,:seq_len],
+                                                            tgt_symbols[:,:seq_len], src_key_padding_mask[:,:seq_len],
                                                             tgt_weights[:,:seq_len])

     if randomize_proportion > 0.0:
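
The trimming step above can be read on its own: for every position, count how many sentences are not padded there, drop trailing positions where that count is zero, then slice all collated tensors to the shorter seq_len. A standalone sketch with a toy mask (editorial, not from the commit):

    import torch

    src_key_padding_mask = torch.tensor([[False, False, True, True],
                                         [False, True,  True, True]])
    attn_mask_sum = torch.sum(torch.logical_not(src_key_padding_mask), dim=0).tolist()
    while attn_mask_sum[-1] == 0:  # drop always-masked trailing positions
        attn_mask_sum.pop()
    seq_len = len(attn_mask_sum)   # 2 in this toy example
    src_key_padding_mask = src_key_padding_mask[:, :seq_len]
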
@@ -409,9 +409,9 @@ def collate_fn(sentences: List[List[int]],
     check_collated_tensors(sentences, bos_sym, eos_sym, blank_sym,
                            unmasked_weight,
                            masked_src_symbols, src_symbols,
-                           tgt_symbols, src_attn_mask, tgt_weights)
+                           tgt_symbols, src_key_padding_mask, tgt_weights)
     return (masked_src_symbols, src_symbols,
-            tgt_symbols, src_attn_mask, tgt_weights)
+            tgt_symbols, src_key_padding_mask, tgt_weights)


@@ -421,20 +421,20 @@ def check_collated_tensors(sentences: List[List[int]],
                            blank_sym: int,
                            unmasked_weight: float,
                            masked_src_symbols, src_symbols,
-                           tgt_symbols, src_attn_mask,
+                           tgt_symbols, src_key_padding_mask,
                            tgt_weights):
     """
     This function checks the output of collate_fn, consider it test code.  Please see
     the documentation of collate_fn to understand the args.
     """
-    for t in src_symbols, tgt_symbols, src_attn_mask, tgt_weights:
+    for t in src_symbols, tgt_symbols, src_key_padding_mask, tgt_weights:
         assert t.shape == masked_src_symbols.shape

     tot_positions = src_symbols.numel()

-    masked_src_symbols, src_symbols, tgt_symbols, src_attn_mask, tgt_weights = (
+    masked_src_symbols, src_symbols, tgt_symbols, src_key_padding_mask, tgt_weights = (
         masked_src_symbols.tolist(), src_symbols.tolist(), tgt_symbols.tolist(),
-        src_attn_mask.tolist(), tgt_weights.tolist())
+        src_key_padding_mask.tolist(), tgt_weights.tolist())
     assert len(sentences) == len(masked_src_symbols)

     tot_masked_positions = 0
@@ -451,7 +451,7 @@ def check_collated_tensors(sentences: List[List[int]],
         if sentences[i] != reconstructed_sent:
             print(f"Error: sentence {i}={sentences[i]} differs from {reconstructed_sent}")
         (masked_src, src, tgt, src_mask, weights) = (masked_src_symbols[i], src_symbols[i],
-                                                     tgt_symbols[i], src_attn_mask[i], tgt_weights[i])
+                                                     tgt_symbols[i], src_key_padding_mask[i], tgt_weights[i])

         assert src[0] == masked_src[0] == bos_sym
         for j in range(len(masked_src)):
@@ -3,9 +3,10 @@
 # python3 -m pytest test_conformer.py

 import torch
+import dataset  # from .
 from conformer import (
-    TransformerDecoderRelPos,
-    TransformerDecoderLayerRelPos,
+    RelPosTransformerDecoder,
+    RelPosTransformerDecoderLayer,
     MaskedLmConformer,
     MaskedLmConformerEncoder,
     MaskedLmConformerEncoderLayer,
@@ -80,7 +81,7 @@ def test_transformer_decoder_layer_rel_pos():
     N = 4
     C = 256
     pos_emb_module = RelPositionalEncoding(C, dropout_rate=0.0)
-    decoder_layer = TransformerDecoderLayerRelPos(embed_dim, num_heads)
+    decoder_layer = RelPosTransformerDecoderLayer(embed_dim, num_heads)


     x = torch.randn(N, T, C)
@@ -100,10 +101,9 @@ def test_transformer_decoder_rel_pos():
     N = 4
     C = 256
     pos_emb_module = RelPositionalEncoding(C, dropout_rate=0.0)
-    decoder_layer = TransformerDecoderLayerRelPos(embed_dim, num_heads)
+    decoder_layer = RelPosTransformerDecoderLayer(embed_dim, num_heads)
     decoder_norm = torch.nn.LayerNorm(embed_dim)
-    decoder = TransformerDecoderRelPos(decoder_layer, num_layers=6, norm=decoder_norm)
-
+    decoder = RelPosTransformerDecoder(decoder_layer, num_layers=6, norm=decoder_norm)

     x = torch.randn(N, T, C)
     x, pos_emb = pos_emb_module(x)
@@ -114,18 +114,30 @@ def test_transformer_decoder_rel_pos():
     y = decoder(x, pos_emb, memory, attn_mask=attn_mask, key_padding_mask=key_padding_mask)


-def test_transformer():
-    return
-    num_features = 40
+def test_masked_lm_conformer():
+
     num_classes = 87
-    model = Transformer(num_features=num_features, num_classes=num_classes)
+    d_model = 256
+
+    model = MaskedLmConformer(num_classes,d_model)
+

     N = 31

-    for T in range(7, 30):
-        x = torch.rand(N, T, num_features)
-        y, _, _ = model(x)
-        assert y.shape == (N, (((T - 1) // 2) - 1) // 2, num_classes)
+
+    (masked_src_symbols, src_symbols,
+     tgt_symbols, src_key_padding_mask,
+     tgt_weights) = dataset.collate_fn(sentences=[ list(range(10, 20)), list(range(30, 45)), list(range(50,68))], bos_sym=1, eos_sym=2,
+                                       blank_sym=0)
+
+    # test forward() of MaskedLmConformer
+    memory, pos_emb = model(masked_src_symbols, src_key_padding_mask)
+    nll = model.decoder_nll(memory, pos_emb, src_symbols, tgt_symbols,
+                            src_key_padding_mask)
+    print("nll = ", nll)
+    loss = (nll * tgt_weights).sum()
+    print("loss = ", loss)
+


 def test_generate_square_subsequent_mask():