Add more testing; fix issue about channel dim of LayerNorm.

Daniel Povey 2021-08-23 17:18:00 +08:00
parent e0b04ba54f
commit 2fbe3b78fd
2 changed files with 24 additions and 4 deletions


@@ -757,8 +757,10 @@ class MaskedConvolutionModule(nn.Module):
         # 1D Depthwise Conv
         x = self.depthwise_conv(x)
-        x = self.activation(self.norm(x))
+        x = x.transpose(1, 2) # (batch, time, channel)
+        x = self.norm(x) # LayerNorm requires channel be last dim.
+        x = x.transpose(1, 2) # (batch, channel, time)
+        x = self.activation(x)
         x = self.pointwise_conv2(x) # (batch, channel, time)
         return x.permute(2, 0, 1) # (time, batch, channel)
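
For context (not part of the commit): torch.nn.LayerNorm(C) normalizes over the trailing dimension(s) of its input, so a (batch, channel, time) tensor has to be moved to a channel-last layout around the norm, which is what the added lines do. A minimal standalone sketch of that pattern, with illustrative shapes only:

import torch
import torch.nn as nn

# nn.LayerNorm(channels) normalizes over the last dimension, so a
# (batch, channel, time) tensor is transposed to channel-last for the
# norm and transposed back afterwards.
batch, channels, time = 4, 256, 25
norm = nn.LayerNorm(channels)

x = torch.randn(batch, channels, time)  # (batch, channel, time)
y = x.transpose(1, 2)                   # (batch, time, channel)
y = norm(y)                             # normalize over the channel dim
y = y.transpose(1, 2)                   # back to (batch, channel, time)
assert y.shape == x.shape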
@@ -807,7 +809,7 @@ class MaskedLmConformerEncoderLayer(nn.Module):
         dropout: float = 0.1,
         cnn_module_kernel: int = 31,
     ) -> None:
-        super(ConformerEncoderLayer, self).__init__()
+        super(MaskedLmConformerEncoderLayer, self).__init__()
         self.self_attn = RelPositionMultiheadAttention(
             d_model, nhead, dropout=0.0
         )
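
As an aside (the class and attribute names below are hypothetical, used only for illustration), this kind of stale super(...) call usually comes from copying a constructor and renaming the class; Python 3's zero-argument super() avoids hard-coding the class name entirely:

import torch.nn as nn

class MyEncoderLayer(nn.Module):
    def __init__(self, d_model: int) -> None:
        # Zero-argument super() resolves the enclosing class automatically,
        # so a copied-and-renamed class cannot keep a stale name here.
        super().__init__()
        self.linear = nn.Linear(d_model, d_model)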


@@ -6,6 +6,7 @@ import torch
 from conformer import (
     TransformerDecoderRelPos,
     MaskedLmConformer,
+    MaskedLmConformerEncoderLayer,
     RelPositionMultiheadAttention,
     RelPositionalEncoding,
     generate_square_subsequent_mask,
@@ -27,11 +28,28 @@ def test_rel_position_multihead_attention():
     x = torch.randn(N, T, C)
     #pos_emb = torch.randn(1, 2*T-1, C)
     x, pos_enc = pos_emb_module(x)
     print("pos_enc.shape=", pos_enc.shape)
     x = x.transpose(0, 1) # (T, N, C)
     attn_output, attn_output_weights = rel_pos_multihead_attn(x, x, x, pos_enc)


+def test_masked_lm_conformer_encoder_layer():
+    # Also tests RelPositionalEncoding
+    embed_dim = 256
+    num_heads = 4
+    T = 25
+    N = 4
+    C = 256
+    pos_emb_module = RelPositionalEncoding(C, dropout_rate=0.0)
+    encoder_layer = MaskedLmConformerEncoderLayer(embed_dim, num_heads)
+    x = torch.randn(N, T, C)
+    x, pos_enc = pos_emb_module(x)
+    x = x.transpose(0, 1) # (T, N, C)
+    key_padding_mask = (torch.randn(N, T) > 0.0) # (N, T)
+    y = encoder_layer(x, pos_enc, key_padding_mask=key_padding_mask)
+
+
 def test_transformer():
     return
     num_features = 40
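
A possible follow-up to the new test (not part of this commit): check the output shape rather than only running the layer. This sketch assumes the encoder layer preserves the (T, N, C) layout of its input, which is typical for conformer encoder layers but is not verified by the diff above; the test name is made up.

import torch
from conformer import MaskedLmConformerEncoderLayer, RelPositionalEncoding

def test_masked_lm_conformer_encoder_layer_shape():
    T, N, C = 25, 4, 256
    pos_emb_module = RelPositionalEncoding(C, dropout_rate=0.0)
    encoder_layer = MaskedLmConformerEncoderLayer(C, 4)
    x, pos_enc = pos_emb_module(torch.randn(N, T, C))
    x = x.transpose(0, 1)  # (T, N, C)
    key_padding_mask = (torch.randn(N, T) > 0.0)  # (N, T)
    y = encoder_layer(x, pos_enc, key_padding_mask=key_padding_mask)
    # Assumption: the layer returns a tensor with the same (T, N, C) shape.
    assert y.shape == (T, N, C)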