diff --git a/egs/librispeech/ASR/conformer_lm/conformer.py b/egs/librispeech/ASR/conformer_lm/conformer.py
index 9f0db2e81..35d700119 100644
--- a/egs/librispeech/ASR/conformer_lm/conformer.py
+++ b/egs/librispeech/ASR/conformer_lm/conformer.py
@@ -757,8 +757,10 @@ class MaskedConvolutionModule(nn.Module):
 
         # 1D Depthwise Conv
         x = self.depthwise_conv(x)
-        x = self.activation(self.norm(x))
-
+        x = x.transpose(1, 2)  # (batch, time, channel)
+        x = self.norm(x)  # LayerNorm requires channel be last dim.
+        x = x.transpose(1, 2)  # (batch, channel, time)
+        x = self.activation(x)
         x = self.pointwise_conv2(x)  # (batch, channel, time)
 
         return x.permute(2, 0, 1)  # (time, batch, channel)
@@ -807,7 +809,7 @@ class MaskedLmConformerEncoderLayer(nn.Module):
         dropout: float = 0.1,
         cnn_module_kernel: int = 31,
     ) -> None:
-        super(ConformerEncoderLayer, self).__init__()
+        super(MaskedLmConformerEncoderLayer, self).__init__()
         self.self_attn = RelPositionMultiheadAttention(
             d_model, nhead, dropout=0.0
         )
diff --git a/egs/librispeech/ASR/conformer_lm/test_conformer.py b/egs/librispeech/ASR/conformer_lm/test_conformer.py
index 8aaae4277..3fdd8a222 100644
--- a/egs/librispeech/ASR/conformer_lm/test_conformer.py
+++ b/egs/librispeech/ASR/conformer_lm/test_conformer.py
@@ -6,6 +6,7 @@ import torch
 from conformer import (
     TransformerDecoderRelPos,
     MaskedLmConformer,
+    MaskedLmConformerEncoderLayer,
     RelPositionMultiheadAttention,
     RelPositionalEncoding,
     generate_square_subsequent_mask,
@@ -27,11 +28,28 @@ def test_rel_position_multihead_attention():
     x = torch.randn(N, T, C)
     #pos_emb = torch.randn(1, 2*T-1, C)
     x, pos_enc = pos_emb_module(x)
-    print("pos_enc.shape=", pos_enc.shape)
     x = x.transpose(0, 1)  # (T, N, C)
     attn_output, attn_output_weights = rel_pos_multihead_attn(x, x, x, pos_enc)
 
 
+def test_masked_lm_conformer_encoder_layer():
+    # Also tests RelPositionalEncoding
+    embed_dim = 256
+    num_heads = 4
+    T = 25
+    N = 4
+    C = 256
+    pos_emb_module = RelPositionalEncoding(C, dropout_rate=0.0)
+    encoder_layer = MaskedLmConformerEncoderLayer(embed_dim, num_heads)
+
+
+    x = torch.randn(N, T, C)
+    x, pos_enc = pos_emb_module(x)
+    x = x.transpose(0, 1)  # (T, N, C)
+    key_padding_mask = (torch.randn(N, T) > 0.0)  # (N, T)
+    y = encoder_layer(x, pos_enc, key_padding_mask=key_padding_mask)
+
+
 def test_transformer():
     return
     num_features = 40
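
Note on the first hunk: torch.nn.LayerNorm normalizes over the trailing dimension, so a (batch, channel, time) tensor has to be moved to channel-last before normalization and back afterwards, which is exactly what the transposes around self.norm do. Below is a minimal standalone sketch of that pattern; the activation here is a stand-in chosen for illustration, not necessarily the one MaskedConvolutionModule uses.

import torch
import torch.nn as nn

batch, channels, time = 4, 256, 25
norm = nn.LayerNorm(channels)           # normalizes over the last dim (channels)
activation = nn.SiLU()                  # stand-in activation for this sketch

x = torch.randn(batch, channels, time)  # (batch, channel, time)
x = x.transpose(1, 2)                   # (batch, time, channel)
x = norm(x)                             # LayerNorm requires channel be last dim
x = x.transpose(1, 2)                   # (batch, channel, time)
x = activation(x)
print(x.shape)                          # torch.Size([4, 256, 25])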