diff --git a/egs/librispeech/ASR/conformer_lm/conformer.py b/egs/librispeech/ASR/conformer_lm/conformer.py
index 6207dab84..1963056cc 100644
--- a/egs/librispeech/ASR/conformer_lm/conformer.py
+++ b/egs/librispeech/ASR/conformer_lm/conformer.py
@@ -52,8 +52,8 @@ class MaskedLmConformer(nn.Module):
         # self.embed is the embedding used for both the encoder and decoder.
         self.embed_scale = d_model ** 0.5
         self.embed = nn.Embedding(
-            num_embeddings=self.decoder_num_class, embedding_dim=d_model,
-            _weight=torch.randn(self.decoder_num_class, d_model) * (1 / self.embed_scale)
+            num_embeddings=self.num_classes, embedding_dim=d_model,
+            _weight=torch.randn(self.num_classes, d_model) * (1 / self.embed_scale)
         )

         self.encoder_pos = RelPositionalEncoding(d_model, dropout)
@@ -69,9 +69,8 @@ class MaskedLmConformer(nn.Module):
                                                  norm=nn.LayerNorm(d_model))

         if num_decoder_layers > 0:
-            self.decoder_num_class = self.num_classes
-
-            decoder_layer = TransformerDecoderLayerRelPos(
+            decoder_layer = RelPosTransformerDecoderLayer(
                 d_model=d_model,
                 nhead=nhead,
                 dim_feedforward=dim_feedforward,
@@ -82,14 +81,14 @@ class MaskedLmConformer(nn.Module):

             self.src_linear = torch.nn.Linear(d_model, d_model)
             decoder_norm = nn.LayerNorm(d_model)
-            self.decoder = TransformerDecoderRelPos(
+            self.decoder = RelPosTransformerDecoder(
                 decoder_layer=decoder_layer,
                 num_layers=num_decoder_layers,
                 norm=decoder_norm,
             )

             self.decoder_output_layer = torch.nn.Linear(
-                d_model, self.decoder_num_class
+                d_model, self.num_classes
             )

@@ -112,8 +111,8 @@ class MaskedLmConformer(nn.Module):

         Returns:
-          Returns (encoded, pos_emb), where:
-             `encoded` is a Tensor containing the encoded data; it is of shape (N, T, C)
+          Returns (memory, pos_emb), where:
+             `memory` is a Tensor containing the encoded data; it is of shape (N, T, C)
              where C is the embedding_dim.
             `pos_emb` is a Tensor containing the relative positional encoding, of
              shape (1, 2*T-1, C)
@@ -164,7 +163,7 @@ class MaskedLmConformer(nn.Module):
         """
         (T, N, C) = memory.shape

-        tgt_mask = generate_square_subsequent_mask(T, memory.device)
+        attn_mask = generate_square_subsequent_mask(T, memory.device)

         x = self.embed(src_symbols) * self.embed_scale  # (N, T) -> (N, T, C)
         x = x.permute(1, 0, 2)  # (N, T, C) -> (T, N, C)
@@ -178,18 +177,17 @@ class MaskedLmConformer(nn.Module):
             x,
             pos_emb,
             memory=memory,
-            tgt_mask=tgt_mask,
-            tgt_key_padding_mask=key_padding_mask,
-            memory_key_padding_mask=key_padding_mask,
-        )  # (T, N, C)
+            attn_mask=attn_mask,
+            key_padding_mask=key_padding_mask)
+        # (T, N, C)
         pred = pred.permute(1, 0, 2)  # (T, N, C) -> (N, T, C)
         pred = self.decoder_output_layer(pred)  # (N, T, C)

         # nll: negative log-likelihood
         nll = torch.nn.functional.cross_entropy(
-            pred.view(-1, self.decoder_num_class),
-            tgt_symbols.view(-1),
+            pred.view(-1, self.num_classes),
+            tgt_symbols.reshape(-1),
             reduction="none",
         )
         nll = nll.view(N, T)
@@ -198,19 +196,19 @@


-class TransformerDecoderRelPos(nn.Module):
-    r"""TransformerDecoderRelPos is a stack of N decoder layers.
+class RelPosTransformerDecoder(nn.Module):
+    r"""RelPosTransformerDecoder is a stack of N decoder layers.
     This is modified from nn.TransformerDecoder to support relative positional encoding.

     Args:
-        decoder_layer: an instance of the TransformerDecoderLayerRelPos() class (required).
+        decoder_layer: an instance of the RelPosTransformerDecoderLayer() class (required).
         num_layers: the number of sub-decoder-layers in the decoder (required).
         norm: the layer normalization component (optional).

     Examples::
-        >>> decoder_layer = nn.TransformerDecoderLayerRelPos(d_model=512, nhead=8)
-        >>> transformer_decoder = nn.TransformerDecoderRelPos(decoder_layer, num_layers=6)
+        >>> decoder_layer = nn.RelPosTransformerDecoderLayer(d_model=512, nhead=8)
+        >>> transformer_decoder = nn.RelPosTransformerDecoder(decoder_layer, num_layers=6)
         >>> memory = torch.rand(10, 32, 512)
         >>> tgt = torch.rand(20, 32, 512)
         >>> pos_enc = torch.rand()
@@ -219,7 +217,7 @@ class TransformerDecoderRelPos(nn.Module):
     __constants__ = ['norm']

     def __init__(self, decoder_layer, num_layers, norm=None):
-        super(TransformerDecoderRelPos, self).__init__()
+        super(RelPosTransformerDecoder, self).__init__()
         self.layers = _get_clones(decoder_layer, num_layers)
         self.num_layers = num_layers
         self.norm = norm
@@ -257,7 +255,7 @@ class TransformerDecoderRelPos(nn.Module):
         return x


-class TransformerDecoderLayerRelPos(nn.Module):
+class RelPosTransformerDecoderLayer(nn.Module):
     """
     Modified from torch.nn.TransformerDecoderLayer.
     Add it to use normalize_before (hardcoded to True), i.e. use layer_norm before the first block;
@@ -278,7 +276,7 @@ class TransformerDecoderLayerRelPos(nn.Module):
             gelu (default=relu).

     Examples::
-        >>> decoder_layer = nn.TransformerDecoderLayerRelPos(d_model=512, nhead=8)
+        >>> decoder_layer = nn.RelPosTransformerDecoderLayer(d_model=512, nhead=8)
         >>> memory = torch.rand(10, 32, 512)
         >>> tgt = torch.rand(20, 32, 512)
         >>> pos_emb = torch.rand(1, 20*2+1, 512)
@@ -293,7 +291,7 @@ class TransformerDecoderLayerRelPos(nn.Module):
         dropout: float = 0.1,
         activation: str = "relu",
     ) -> None:
-        super(TransformerDecoderLayerRelPos, self).__init__()
+        super(RelPosTransformerDecoderLayer, self).__init__()
         self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0)
         self.src_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0)
         # Implementation of Feedforward model
@@ -314,7 +312,7 @@ class TransformerDecoderLayerRelPos(nn.Module):
     def __setstate__(self, state):
         if "activation" not in state:
             state["activation"] = nn.functional.relu
-        super(TransformerDecoderLayerRelPos, self).__setstate__(state)
+        super(RelPosTransformerDecoderLayer, self).__setstate__(state)

     def forward(
         self,
diff --git a/egs/librispeech/ASR/conformer_lm/dataset.py b/egs/librispeech/ASR/conformer_lm/dataset.py
index 75f603d9d..cad2c4d8f 100644
--- a/egs/librispeech/ASR/conformer_lm/dataset.py
+++ b/egs/librispeech/ASR/conformer_lm/dataset.py
@@ -297,7 +297,7 @@ def collate_fn(sentences: List[List[int]],
            Will be reflected in the returned tgt_weights tensor.

     Returns a tuple (masked_src_symbols, src_symbols,
-                     tgt_symbols, src_attn_mask,
+                     tgt_symbols, src_key_padding_mask,
                      tgt_weights),
     all with 2 axes and the same shape: (num_sent, seq_len).
     Their dtypes will be, respectively,
@@ -315,7 +315,7 @@ def collate_fn(sentences: List[List[int]],
      tgt_symbols:  The original sentences, with eos_symbol appended, and then
                 padded with blank to the same length as masked_symbols and
                 src_symbols.
-     src_attn_mask:  Masking tensor for masked_src_symbols and src_symbols, to
+     src_key_padding_mask:  Masking tensor for masked_src_symbols and src_symbols, to
                 account for all the sentence lengths not being identical
                 (makes each sentence's processing independent of seq_len).
                Tensor of Bool of shape (num_sent, seq_len), with True
@@ -368,17 +368,17 @@ def collate_fn(sentences: List[List[int]],
     src_symbols = torch.tensor(srcs, dtype=torch.int64)
     masked_src_symbols = torch.tensor(srcs_masked, dtype=torch.int64)
     tgt_symbols = torch.tensor(tgts, dtype=torch.int64)
-    src_attn_mask = torch.tensor(attn_masks, dtype=torch.bool)
+    src_key_padding_mask = torch.tensor(attn_masks, dtype=torch.bool)
     tgt_weights = torch.tensor(weights, dtype=torch.float)

-    attn_mask_sum = torch.sum(torch.logical_not(src_attn_mask), dim=0).tolist()
+    attn_mask_sum = torch.sum(torch.logical_not(src_key_padding_mask), dim=0).tolist()
     while attn_mask_sum[-1] == 0:  # Remove always-masked positions at the endof the lists.
         attn_mask_sum.pop()
     if len(attn_mask_sum) < seq_len:
         seq_len = len(attn_mask_sum)

     (src_symbols, masked_src_symbols,
-     tgt_symbols, src_attn_mask, tgt_weights) = (src_symbols[:,:seq_len], masked_src_symbols[:,:seq_len],
-                                                 tgt_symbols[:,:seq_len], src_attn_mask[:,:seq_len],
+     tgt_symbols, src_key_padding_mask, tgt_weights) = (src_symbols[:,:seq_len], masked_src_symbols[:,:seq_len],
+                                                        tgt_symbols[:,:seq_len], src_key_padding_mask[:,:seq_len],
                                                  tgt_weights[:,:seq_len])

     if randomize_proportion > 0.0:
@@ -409,9 +409,9 @@ def collate_fn(sentences: List[List[int]],
     check_collated_tensors(sentences, bos_sym, eos_sym,
                            blank_sym, unmasked_weight,
                            masked_src_symbols, src_symbols,
-                           tgt_symbols, src_attn_mask, tgt_weights)
+                           tgt_symbols, src_key_padding_mask, tgt_weights)
     return (masked_src_symbols, src_symbols,
-            tgt_symbols, src_attn_mask, tgt_weights)
+            tgt_symbols, src_key_padding_mask, tgt_weights)



@@ -421,20 +421,20 @@ def check_collated_tensors(sentences: List[List[int]],
                            blank_sym: int,
                            unmasked_weight: float,
                            masked_src_symbols, src_symbols,
-                           tgt_symbols, src_attn_mask,
+                           tgt_symbols, src_key_padding_mask,
                            tgt_weights):
     """
     This function checks the output of collate_fn, consider it test code.  Please see
     the documentation of collate_fn to understand the args.
     """
-    for t in src_symbols, tgt_symbols, src_attn_mask, tgt_weights:
+    for t in src_symbols, tgt_symbols, src_key_padding_mask, tgt_weights:
         assert t.shape == masked_src_symbols.shape

     tot_positions = src_symbols.numel()

-    masked_src_symbols, src_symbols, tgt_symbols, src_attn_mask, tgt_weights = (
+    masked_src_symbols, src_symbols, tgt_symbols, src_key_padding_mask, tgt_weights = (
         masked_src_symbols.tolist(), src_symbols.tolist(), tgt_symbols.tolist(),
-        src_attn_mask.tolist(), tgt_weights.tolist())
+        src_key_padding_mask.tolist(), tgt_weights.tolist())

     assert len(sentences) == len(masked_src_symbols)
     tot_masked_positions = 0
@@ -451,7 +451,7 @@ def check_collated_tensors(sentences: List[List[int]],
         if sentences[i] != reconstructed_sent:
             print(f"Error: sentence {i}={sentences[i]} differs from {reconstructed_sent}")
         (masked_src, src, tgt, src_mask, weights) = (masked_src_symbols[i], src_symbols[i],
-                                                     tgt_symbols[i], src_attn_mask[i], tgt_weights[i])
+                                                     tgt_symbols[i], src_key_padding_mask[i], tgt_weights[i])

         assert src[0] == masked_src[0] == bos_sym
         for j in range(len(masked_src)):
diff --git a/egs/librispeech/ASR/conformer_lm/test_conformer.py b/egs/librispeech/ASR/conformer_lm/test_conformer.py
index 99acfdcd0..45b50a6ea 100644
--- a/egs/librispeech/ASR/conformer_lm/test_conformer.py
+++ b/egs/librispeech/ASR/conformer_lm/test_conformer.py
@@ -3,9 +3,10 @@
 # python3 -m pytest test_conformer.py

 import torch
+import dataset  # from .
 from conformer import (
-    TransformerDecoderRelPos,
-    TransformerDecoderLayerRelPos,
+    RelPosTransformerDecoder,
+    RelPosTransformerDecoderLayer,
     MaskedLmConformer,
     MaskedLmConformerEncoder,
     MaskedLmConformerEncoderLayer,
@@ -80,7 +81,7 @@ def test_transformer_decoder_layer_rel_pos():
     N = 4
     C = 256
     pos_emb_module = RelPositionalEncoding(C, dropout_rate=0.0)
-    decoder_layer = TransformerDecoderLayerRelPos(embed_dim, num_heads)
+    decoder_layer = RelPosTransformerDecoderLayer(embed_dim, num_heads)

     x = torch.randn(N, T, C)

@@ -100,10 +101,9 @@ def test_transformer_decoder_rel_pos():
     N = 4
     C = 256
     pos_emb_module = RelPositionalEncoding(C, dropout_rate=0.0)
-    decoder_layer = TransformerDecoderLayerRelPos(embed_dim, num_heads)
+    decoder_layer = RelPosTransformerDecoderLayer(embed_dim, num_heads)
     decoder_norm = torch.nn.LayerNorm(embed_dim)
-    decoder = TransformerDecoderRelPos(decoder_layer, num_layers=6, norm=decoder_norm)
-
+    decoder = RelPosTransformerDecoder(decoder_layer, num_layers=6, norm=decoder_norm)

     x = torch.randn(N, T, C)
     x, pos_emb = pos_emb_module(x)
@@ -114,18 +114,30 @@ def test_transformer_decoder_rel_pos():
     y = decoder(x, pos_emb, memory, attn_mask=attn_mask, key_padding_mask=key_padding_mask)


-def test_transformer():
-    return
-    num_features = 40
+def test_masked_lm_conformer():
+
     num_classes = 87
-    model = Transformer(num_features=num_features, num_classes=num_classes)
+    d_model = 256
+
+    model = MaskedLmConformer(num_classes,d_model)
+
     N = 31
-    for T in range(7, 30):
-        x = torch.rand(N, T, num_features)
-        y, _, _ = model(x)
-        assert y.shape == (N, (((T - 1) // 2) - 1) // 2, num_classes)
+
+    (masked_src_symbols, src_symbols,
+     tgt_symbols, src_key_padding_mask,
+     tgt_weights) = dataset.collate_fn(sentences=[ list(range(10, 20)), list(range(30, 45)), list(range(50,68))], bos_sym=1, eos_sym=2,
+                                       blank_sym=0)
+
+    # test forward() of MaskedLmConformer
+    memory, pos_emb = model(masked_src_symbols, src_key_padding_mask)
+    nll = model.decoder_nll(memory, pos_emb, src_symbols, tgt_symbols,
+                            src_key_padding_mask)
+    print("nll = ", nll)
+    loss = (nll * tgt_weights).sum()
+    print("loss = ", loss)
+

 def test_generate_square_subsequent_mask():
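
Usage sketch (not part of the patch): the snippet below mirrors test_masked_lm_conformer in the diff above and shows how the renamed classes and the renamed collate_fn output fit together. It assumes the patched conformer.py and dataset.py are importable from the conformer_lm directory; the sizes and symbol ids (num_classes=87, d_model=256, bos_sym=1, eos_sym=2, blank_sym=0) are illustrative values copied from the test, not requirements.

# Minimal usage sketch, assuming the patched conformer.py and dataset.py are on sys.path.
import dataset
from conformer import MaskedLmConformer

num_classes = 87   # illustrative vocabulary size, same as the test above
d_model = 256

model = MaskedLmConformer(num_classes, d_model)

# collate_fn pads a batch of sentences and returns the renamed src_key_padding_mask.
(masked_src_symbols, src_symbols,
 tgt_symbols, src_key_padding_mask,
 tgt_weights) = dataset.collate_fn(
     sentences=[list(range(10, 20)), list(range(30, 45)), list(range(50, 68))],
     bos_sym=1, eos_sym=2, blank_sym=0)

# Encoder pass: returns (memory, pos_emb) as described in the updated docstring.
memory, pos_emb = model(masked_src_symbols, src_key_padding_mask)

# Per-position negative log-likelihood from the relative-position decoder.
nll = model.decoder_nll(memory, pos_emb, src_symbols, tgt_symbols,
                        src_key_padding_mask)
loss = (nll * tgt_weights).sum()
print("loss =", loss.item())

As in the test, the per-position nll is weighted by tgt_weights before summing, so padded positions (and, depending on unmasked_weight, unmasked positions) are down-weighted in the loss.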