Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-08-13 20:12:24 +00:00)
Add more testing; fix an issue with the channel dim expected by LayerNorm.

Commit 2fbe3b78fd (parent e0b04ba54f)
@@ -757,8 +757,10 @@ class MaskedConvolutionModule(nn.Module):
         # 1D Depthwise Conv
         x = self.depthwise_conv(x)
-        x = self.activation(self.norm(x))
+        x = x.transpose(1, 2)  # (batch, time, channel)
+        x = self.norm(x)  # LayerNorm requires channel be last dim.
+        x = x.transpose(1, 2)  # (batch, channel, time)
+        x = self.activation(x)
         x = self.pointwise_conv2(x)  # (batch, channel, time)

         return x.permute(2, 0, 1)  # (time, batch, channel)
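For context on the change above: torch.nn.LayerNorm normalizes over the trailing dimension(s) of its input, so a (batch, channel, time) tensor has to be moved to (batch, time, channel) before the norm and moved back afterwards. A minimal standalone sketch of that pattern (shapes and variable names below are illustrative, not taken from conformer.py):

    import torch
    import torch.nn as nn

    batch, channel, time = 4, 256, 25
    x = torch.randn(batch, channel, time)

    norm = nn.LayerNorm(channel)  # normalized_shape must match the last input dim

    # Applying norm(x) directly would fail here: the last dim of x is time (25),
    # not channel (256).  The fix in the diff transposes around the norm instead:
    y = norm(x.transpose(1, 2)).transpose(1, 2)  # back to (batch, channel, time)
    print(y.shape)  # torch.Size([4, 256, 25])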
@@ -807,7 +809,7 @@ class MaskedLmConformerEncoderLayer(nn.Module):
         dropout: float = 0.1,
         cnn_module_kernel: int = 31,
     ) -> None:
-        super(ConformerEncoderLayer, self).__init__()
+        super(MaskedLmConformerEncoderLayer, self).__init__()
         self.self_attn = RelPositionMultiheadAttention(
             d_model, nhead, dropout=0.0
         )
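The hunk above fixes a leftover from copying the class: the constructor passed ConformerEncoderLayer, a different class, to super(), which raises a NameError if that name is not in scope, or a TypeError otherwise, since self is not an instance of that class. On Python 3 the zero-argument form avoids hard-coding the class name at all; a small sketch under that assumption (the attribute assignments are placeholders, not the real constructor body):

    import torch.nn as nn

    class MaskedLmConformerEncoderLayer(nn.Module):
        def __init__(self, d_model: int, nhead: int) -> None:
            # Same effect as super(MaskedLmConformerEncoderLayer, self).__init__(),
            # but cannot go stale when the class is renamed or copied.
            super().__init__()
            self.d_model = d_model
            self.nhead = nhead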
@@ -6,6 +6,7 @@ import torch
 from conformer import (
     TransformerDecoderRelPos,
     MaskedLmConformer,
+    MaskedLmConformerEncoderLayer,
     RelPositionMultiheadAttention,
     RelPositionalEncoding,
     generate_square_subsequent_mask,
@@ -27,11 +28,28 @@ def test_rel_position_multihead_attention():
     x = torch.randn(N, T, C)
     #pos_emb = torch.randn(1, 2*T-1, C)
     x, pos_enc = pos_emb_module(x)
     print("pos_enc.shape=", pos_enc.shape)
     x = x.transpose(0, 1)  # (T, N, C)
     attn_output, attn_output_weights = rel_pos_multihead_attn(x, x, x, pos_enc)


+def test_masked_lm_conformer_encoder_layer():
+    # Also tests RelPositionalEncoding
+    embed_dim = 256
+    num_heads = 4
+    T = 25
+    N = 4
+    C = 256
+    pos_emb_module = RelPositionalEncoding(C, dropout_rate=0.0)
+    encoder_layer = MaskedLmConformerEncoderLayer(embed_dim, num_heads)
+
+    x = torch.randn(N, T, C)
+    x, pos_enc = pos_emb_module(x)
+    x = x.transpose(0, 1)  # (T, N, C)
+    key_padding_mask = (torch.randn(N, T) > 0.0)  # (N, T)
+    y = encoder_layer(x, pos_enc, key_padding_mask=key_padding_mask)
+
+
 def test_transformer():
     return
     num_features = 40
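The new test above exercises the forward pass only, so it catches crashes (like the LayerNorm shape error fixed in this commit) but not a silently wrong output shape. A possible follow-up assertion, sketched under the assumption that MaskedLmConformerEncoderLayer returns a tensor in the same (T, N, C) layout as its input:

    # Hypothetical extra check at the end of test_masked_lm_conformer_encoder_layer():
    assert y.shape == (T, N, C), f"unexpected output shape: {y.shape}"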