mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-14 04:22:21 +00:00
Add testing for MaskedLmConformerEncoder
This commit is contained in:
parent
2fbe3b78fd
commit
556fae586f
@ -1,6 +1,7 @@
|
|||||||
# Copyright (c) 2021 University of Chinese Academy of Sciences (author: Han Zhu)
|
# Copyright (c) 2021 University of Chinese Academy of Sciences (author: Han Zhu)
|
||||||
# Apache 2.0
|
# Apache 2.0
|
||||||
|
|
||||||
|
import copy
|
||||||
import math
|
import math
|
||||||
from typing import Dict, List, Optional, Tuple
|
from typing import Dict, List, Optional, Tuple
|
||||||
|
|
||||||
@ -910,7 +911,7 @@ class MaskedLmConformerEncoderLayer(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
def _get_clones(module, N):
|
def _get_clones(module, N):
|
||||||
return ModuleList([copy.deepcopy(module) for i in range(N)])
|
return torch.nn.ModuleList([copy.deepcopy(module) for i in range(N)])
|
||||||
|
|
||||||
class MaskedLmConformerEncoder(nn.Module):
|
class MaskedLmConformerEncoder(nn.Module):
|
||||||
r"""MaskedLmConformerEncoder is a stack of N encoder layers, modified from
|
r"""MaskedLmConformerEncoder is a stack of N encoder layers, modified from
|
||||||
|
@ -6,6 +6,7 @@ import torch
|
|||||||
from conformer import (
|
from conformer import (
|
||||||
TransformerDecoderRelPos,
|
TransformerDecoderRelPos,
|
||||||
MaskedLmConformer,
|
MaskedLmConformer,
|
||||||
|
MaskedLmConformerEncoder,
|
||||||
MaskedLmConformerEncoderLayer,
|
MaskedLmConformerEncoderLayer,
|
||||||
RelPositionMultiheadAttention,
|
RelPositionMultiheadAttention,
|
||||||
RelPositionalEncoding,
|
RelPositionalEncoding,
|
||||||
@ -50,6 +51,28 @@ def test_masked_lm_conformer_encoder_layer():
|
|||||||
y = encoder_layer(x, pos_enc, key_padding_mask=key_padding_mask)
|
y = encoder_layer(x, pos_enc, key_padding_mask=key_padding_mask)
|
||||||
|
|
||||||
|
|
||||||
|
def test_masked_lm_conformer_encoder():
|
||||||
|
# Also tests RelPositionalEncoding
|
||||||
|
embed_dim = 256
|
||||||
|
num_heads = 4
|
||||||
|
T = 25
|
||||||
|
N = 4
|
||||||
|
C = 256
|
||||||
|
pos_emb_module = RelPositionalEncoding(C, dropout_rate=0.0)
|
||||||
|
encoder_layer = MaskedLmConformerEncoderLayer(embed_dim, num_heads)
|
||||||
|
norm = torch.nn.LayerNorm(embed_dim)
|
||||||
|
encoder = MaskedLmConformerEncoder(encoder_layer, num_layers=4,
|
||||||
|
norm=norm)
|
||||||
|
|
||||||
|
|
||||||
|
x = torch.randn(N, T, C)
|
||||||
|
x, pos_enc = pos_emb_module(x)
|
||||||
|
x = x.transpose(0, 1) # (T, N, C)
|
||||||
|
key_padding_mask = (torch.randn(N, T) > 0.0) # (N, T)
|
||||||
|
y = encoder(x, pos_enc, key_padding_mask=key_padding_mask)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def test_transformer():
|
def test_transformer():
|
||||||
return
|
return
|
||||||
num_features = 40
|
num_features = 40
|
||||||
|
Loading…
x
Reference in New Issue
Block a user