Add more testing; fix issue about channel dim of LayerNorm.

Daniel Povey 2021-08-23 17:18:00 +08:00
parent e0b04ba54f
commit 2fbe3b78fd
2 changed files with 24 additions and 4 deletions

View File

@@ -757,8 +757,10 @@ class MaskedConvolutionModule(nn.Module):
         # 1D Depthwise Conv
         x = self.depthwise_conv(x)
 
-        x = self.activation(self.norm(x))
-
+        x = x.transpose(1, 2)  # (batch, time, channel)
+        x = self.norm(x)  # LayerNorm requires channel be last dim.
+        x = x.transpose(1, 2)  # (batch, channel, time)
+        x = self.activation(x)
         x = self.pointwise_conv2(x)  # (batch, channel, time)
 
         return x.permute(2, 0, 1)  # (time, batch, channel)
@@ -807,7 +809,7 @@ class MaskedLmConformerEncoderLayer(nn.Module):
         dropout: float = 0.1,
         cnn_module_kernel: int = 31,
     ) -> None:
-        super(ConformerEncoderLayer, self).__init__()
+        super(MaskedLmConformerEncoderLayer, self).__init__()
         self.self_attn = RelPositionMultiheadAttention(
             d_model, nhead, dropout=0.0
         )
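
For context on the fix above: torch.nn.LayerNorm normalizes over the trailing dimension(s) given by its normalized_shape, so it needs the channel dimension last, while the surrounding convolutions keep the tensor as (batch, channel, time). A minimal standalone sketch of that constraint in plain PyTorch (not taken from the repository):

    import torch
    import torch.nn as nn

    batch, channel, time = 4, 256, 25
    x = torch.randn(batch, channel, time)
    norm = nn.LayerNorm(channel)

    # LayerNorm(channel) expects the last dim to be `channel`, so move it there,
    # normalize, then move it back for the following pointwise convolution.
    y = norm(x.transpose(1, 2)).transpose(1, 2)  # back to (batch, channel, time)
    assert y.shape == x.shape

    # Applying LayerNorm(channel) directly to (batch, channel, time) would raise
    # a shape error, since the last dimension (time=25) does not match
    # normalized_shape=(256,).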

View File

@@ -6,6 +6,7 @@ import torch
 from conformer import (
     TransformerDecoderRelPos,
     MaskedLmConformer,
+    MaskedLmConformerEncoderLayer,
     RelPositionMultiheadAttention,
     RelPositionalEncoding,
     generate_square_subsequent_mask,
@@ -27,11 +28,28 @@ def test_rel_position_multihead_attention():
     x = torch.randn(N, T, C)
     #pos_emb = torch.randn(1, 2*T-1, C)
     x, pos_enc = pos_emb_module(x)
-    print("pos_enc.shape=", pos_enc.shape)
     x = x.transpose(0, 1)  # (T, N, C)
     attn_output, attn_output_weights = rel_pos_multihead_attn(x, x, x, pos_enc)
 
 
+def test_masked_lm_conformer_encoder_layer():
+    # Also tests RelPositionalEncoding
+    embed_dim = 256
+    num_heads = 4
+    T = 25
+    N = 4
+    C = 256
+
+    pos_emb_module = RelPositionalEncoding(C, dropout_rate=0.0)
+    encoder_layer = MaskedLmConformerEncoderLayer(embed_dim, num_heads)
+
+    x = torch.randn(N, T, C)
+    x, pos_enc = pos_emb_module(x)
+    x = x.transpose(0, 1)  # (T, N, C)
+    key_padding_mask = (torch.randn(N, T) > 0.0)  # (N, T)
+    y = encoder_layer(x, pos_enc, key_padding_mask=key_padding_mask)
+
+
 def test_transformer():
     return
     num_features = 40
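
The new test drives the encoder layer with a random boolean key_padding_mask of shape (N, T). In practice such a mask is usually built from per-utterance lengths, with True marking padded frames that attention should ignore, following the convention of torch.nn.MultiheadAttention (assuming the custom RelPositionMultiheadAttention keeps that convention). A small sketch with a hypothetical helper name not present in the diff:

    import torch

    def lengths_to_padding_mask(lengths: torch.Tensor, max_len: int) -> torch.Tensor:
        # lengths: (N,) tensor of valid frame counts; returns an (N, max_len) bool
        # mask where True marks padded positions.
        return torch.arange(max_len, device=lengths.device).unsqueeze(0) >= lengths.unsqueeze(1)

    # e.g. four sequences padded to T=25 frames:
    key_padding_mask = lengths_to_padding_mask(torch.tensor([25, 20, 18, 10]), max_len=25)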