Add testing for MaskedLmConformerEncoder

Author: Daniel Povey
Date: 2021-08-23 17:22:03 +08:00
Parent: 2fbe3b78fd
Commit: 556fae586f
2 changed files with 25 additions and 1 deletion


@@ -1,6 +1,7 @@
 # Copyright (c) 2021 University of Chinese Academy of Sciences (author: Han Zhu)
 # Apache 2.0
+import copy
 import math
 from typing import Dict, List, Optional, Tuple
@@ -910,7 +911,7 @@ class MaskedLmConformerEncoderLayer(nn.Module):
 def _get_clones(module, N):
-    return ModuleList([copy.deepcopy(module) for i in range(N)])
+    return torch.nn.ModuleList([copy.deepcopy(module) for i in range(N)])
 class MaskedLmConformerEncoder(nn.Module):
     r"""MaskedLmConformerEncoder is a stack of N encoder layers, modified from

@@ -6,6 +6,7 @@ import torch
 from conformer import (
     TransformerDecoderRelPos,
     MaskedLmConformer,
+    MaskedLmConformerEncoder,
     MaskedLmConformerEncoderLayer,
     RelPositionMultiheadAttention,
     RelPositionalEncoding,
@@ -50,6 +51,28 @@ def test_masked_lm_conformer_encoder_layer():
     y = encoder_layer(x, pos_enc, key_padding_mask=key_padding_mask)
+def test_masked_lm_conformer_encoder():
+    # Also tests RelPositionalEncoding
+    embed_dim = 256
+    num_heads = 4
+    T = 25
+    N = 4
+    C = 256
+    pos_emb_module = RelPositionalEncoding(C, dropout_rate=0.0)
+    encoder_layer = MaskedLmConformerEncoderLayer(embed_dim, num_heads)
+    norm = torch.nn.LayerNorm(embed_dim)
+    encoder = MaskedLmConformerEncoder(encoder_layer, num_layers=4,
+                                       norm=norm)
+    x = torch.randn(N, T, C)
+    x, pos_enc = pos_emb_module(x)
+    x = x.transpose(0, 1)  # (T, N, C)
+    key_padding_mask = (torch.randn(N, T) > 0.0)  # (N, T)
+    y = encoder(x, pos_enc, key_padding_mask=key_padding_mask)
 def test_transformer():
     return
     num_features = 40
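The new test only checks that the forward pass runs end to end. A natural follow-up, sketched below, would also assert the output shape; the expectation that the encoder preserves the (T, N, C) layout of its input is an assumption here, not something the commit itself verifies.

import torch

from conformer import (
    MaskedLmConformerEncoder,
    MaskedLmConformerEncoderLayer,
    RelPositionalEncoding,
)


def test_masked_lm_conformer_encoder_shape():
    # Hypothetical extra test: same setup as test_masked_lm_conformer_encoder(),
    # plus a shape check on the encoder output.
    embed_dim, num_heads, T, N, C = 256, 4, 25, 4, 256
    pos_emb_module = RelPositionalEncoding(C, dropout_rate=0.0)
    encoder_layer = MaskedLmConformerEncoderLayer(embed_dim, num_heads)
    encoder = MaskedLmConformerEncoder(encoder_layer, num_layers=4,
                                       norm=torch.nn.LayerNorm(embed_dim))

    x = torch.randn(N, T, C)
    x, pos_enc = pos_emb_module(x)
    x = x.transpose(0, 1)  # (T, N, C)
    key_padding_mask = torch.randn(N, T) > 0.0  # (N, T)

    y = encoder(x, pos_enc, key_padding_mask=key_padding_mask)
    # Assumes the encoder returns a tensor with the same (T, N, C) shape.
    assert y.shape == (T, N, C)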