mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-13 20:12:24 +00:00
Add testing for MaskedLmConformerEncoder
This commit is contained in:
parent
2fbe3b78fd
commit
556fae586f
@ -1,6 +1,7 @@
|
||||
# Copyright (c) 2021 University of Chinese Academy of Sciences (author: Han Zhu)
|
||||
# Apache 2.0
|
||||
|
||||
import copy
|
||||
import math
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
@ -910,7 +911,7 @@ class MaskedLmConformerEncoderLayer(nn.Module):
|
||||
|
||||
|
||||
def _get_clones(module, N):
|
||||
return ModuleList([copy.deepcopy(module) for i in range(N)])
|
||||
return torch.nn.ModuleList([copy.deepcopy(module) for i in range(N)])
|
||||
|
||||
class MaskedLmConformerEncoder(nn.Module):
|
||||
r"""MaskedLmConformerEncoder is a stack of N encoder layers, modified from
|
||||
|
@ -6,6 +6,7 @@ import torch
|
||||
from conformer import (
|
||||
TransformerDecoderRelPos,
|
||||
MaskedLmConformer,
|
||||
MaskedLmConformerEncoder,
|
||||
MaskedLmConformerEncoderLayer,
|
||||
RelPositionMultiheadAttention,
|
||||
RelPositionalEncoding,
|
||||
@ -50,6 +51,28 @@ def test_masked_lm_conformer_encoder_layer():
|
||||
y = encoder_layer(x, pos_enc, key_padding_mask=key_padding_mask)
|
||||
|
||||
|
||||
def test_masked_lm_conformer_encoder():
|
||||
# Also tests RelPositionalEncoding
|
||||
embed_dim = 256
|
||||
num_heads = 4
|
||||
T = 25
|
||||
N = 4
|
||||
C = 256
|
||||
pos_emb_module = RelPositionalEncoding(C, dropout_rate=0.0)
|
||||
encoder_layer = MaskedLmConformerEncoderLayer(embed_dim, num_heads)
|
||||
norm = torch.nn.LayerNorm(embed_dim)
|
||||
encoder = MaskedLmConformerEncoder(encoder_layer, num_layers=4,
|
||||
norm=norm)
|
||||
|
||||
|
||||
x = torch.randn(N, T, C)
|
||||
x, pos_enc = pos_emb_module(x)
|
||||
x = x.transpose(0, 1) # (T, N, C)
|
||||
key_padding_mask = (torch.randn(N, T) > 0.0) # (N, T)
|
||||
y = encoder(x, pos_enc, key_padding_mask=key_padding_mask)
|
||||
|
||||
|
||||
|
||||
def test_transformer():
|
||||
return
|
||||
num_features = 40
|
||||
|
Loading…
x
Reference in New Issue
Block a user