diff --git a/egs/librispeech/ASR/conformer_lm/conformer.py b/egs/librispeech/ASR/conformer_lm/conformer.py
index e158e88d5..6207dab84 100644
--- a/egs/librispeech/ASR/conformer_lm/conformer.py
+++ b/egs/librispeech/ASR/conformer_lm/conformer.py
@@ -166,24 +166,24 @@ class MaskedLmConformer(nn.Module):
 
         tgt_mask = generate_square_subsequent_mask(T, memory.device)
 
-        src = self.embed(src_symbols) * self.embed_scale  # (N, T) -> (N, T, C)
-        src = src.permute(1, 0, 2)  # (N, T, C) -> (T, N, C)
+        x = self.embed(src_symbols) * self.embed_scale  # (N, T) -> (N, T, C)
+        x = x.permute(1, 0, 2)  # (N, T, C) -> (T, N, C)
 
-
-        src = memory + self.src_linear(src)  # (T, N, C)
+        x = memory + self.src_linear(x)  # (T, N, C)
 
         # This is a little confusing, how "tgt" is set to src.  "src" is the
         # symbol sequence without masking but with padding and randomization.
         # "tgt" is like "src" but shifted by one.
         pred = self.decoder(
-            tgt=src,
+            x,
+            pos_emb,
             memory=memory,
             tgt_mask=tgt_mask,
             tgt_key_padding_mask=key_padding_mask,
             memory_key_padding_mask=key_padding_mask,
         )  # (T, N, C)
 
-        pred = pred_pad.permute(1, 0, 2)  # (T, N, C) -> (N, T, C)
+        pred = pred.permute(1, 0, 2)  # (T, N, C) -> (N, T, C)
         pred = self.decoder_output_layer(pred)  # (N, T, C)
 
         # nll: negative log-likelihood
@@ -247,15 +247,14 @@ class TransformerDecoderRelPos(nn.Module):
             a torch.Tensor with dtype=bool and shape (N, T): true for masked
             positions after the ends of sequences.
         """
-
         for mod in self.layers:
             x = mod(x, pos_emb, memory, attn_mask=attn_mask,
                     key_padding_mask=key_padding_mask)
 
         if self.norm is not None:
-            output = self.norm(output)
+            x = self.norm(x)
 
-        return output
+        return x
 
 
 class TransformerDecoderLayerRelPos(nn.Module):
diff --git a/egs/librispeech/ASR/conformer_lm/test_conformer.py b/egs/librispeech/ASR/conformer_lm/test_conformer.py
index 106b84738..99acfdcd0 100644
--- a/egs/librispeech/ASR/conformer_lm/test_conformer.py
+++ b/egs/librispeech/ASR/conformer_lm/test_conformer.py
@@ -74,7 +74,6 @@ def test_masked_lm_conformer_encoder():
 
 
 def test_transformer_decoder_layer_rel_pos():
-    # Also tests RelPositionalEncoding
     embed_dim = 256
     num_heads = 4
     T = 25
@@ -94,6 +93,26 @@ def test_transformer_decoder_layer_rel_pos():
 
 
 
 
+def test_transformer_decoder_rel_pos():
+    embed_dim = 256
+    num_heads = 4
+    T = 25
+    N = 4
+    C = 256
+    pos_emb_module = RelPositionalEncoding(C, dropout_rate=0.0)
+    decoder_layer = TransformerDecoderLayerRelPos(embed_dim, num_heads)
+    decoder_norm = torch.nn.LayerNorm(embed_dim)
+    decoder = TransformerDecoderRelPos(decoder_layer, num_layers=6, norm=decoder_norm)
+
+
+    x = torch.randn(N, T, C)
+    x, pos_emb = pos_emb_module(x)
+    x = x.transpose(0, 1)  # (T, N, C)
+    key_padding_mask = (torch.randn(N, T) > 0.0)  # (N, T)
+    attn_mask = generate_square_subsequent_mask(T)
+    memory = torch.randn(T, N, C)
+    y = decoder(x, pos_emb, memory, attn_mask=attn_mask, key_padding_mask=key_padding_mask)
+
 def test_transformer():
     return