From 556fae586fdaf078253e7213e2fd68acf9b74ae4 Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Mon, 23 Aug 2021 17:22:03 +0800
Subject: [PATCH] Add testing for MaskedLmConformerEncoder

---
 egs/librispeech/ASR/conformer_lm/conformer.py |  3 ++-
 .../ASR/conformer_lm/test_conformer.py        | 23 +++++++++++++++++++
 2 files changed, 25 insertions(+), 1 deletion(-)

diff --git a/egs/librispeech/ASR/conformer_lm/conformer.py b/egs/librispeech/ASR/conformer_lm/conformer.py
index 35d700119..163f47543 100644
--- a/egs/librispeech/ASR/conformer_lm/conformer.py
+++ b/egs/librispeech/ASR/conformer_lm/conformer.py
@@ -1,6 +1,7 @@
 # Copyright (c) 2021 University of Chinese Academy of Sciences (author: Han Zhu)
 # Apache 2.0
 
+import copy
 import math
 from typing import Dict, List, Optional, Tuple
 
@@ -910,7 +911,7 @@ class MaskedLmConformerEncoderLayer(nn.Module):
 
 
 def _get_clones(module, N):
-    return ModuleList([copy.deepcopy(module) for i in range(N)])
+    return torch.nn.ModuleList([copy.deepcopy(module) for i in range(N)])
 
 class MaskedLmConformerEncoder(nn.Module):
     r"""MaskedLmConformerEncoder is a stack of N encoder layers, modified from
diff --git a/egs/librispeech/ASR/conformer_lm/test_conformer.py b/egs/librispeech/ASR/conformer_lm/test_conformer.py
index 3fdd8a222..8c2b2efa4 100644
--- a/egs/librispeech/ASR/conformer_lm/test_conformer.py
+++ b/egs/librispeech/ASR/conformer_lm/test_conformer.py
@@ -6,6 +6,7 @@ import torch
 from conformer import (
     TransformerDecoderRelPos,
     MaskedLmConformer,
+    MaskedLmConformerEncoder,
     MaskedLmConformerEncoderLayer,
     RelPositionMultiheadAttention,
     RelPositionalEncoding,
@@ -50,6 +51,28 @@ def test_masked_lm_conformer_encoder_layer():
     y = encoder_layer(x, pos_enc, key_padding_mask=key_padding_mask)
 
 
+def test_masked_lm_conformer_encoder():
+    # Also tests RelPositionalEncoding
+    embed_dim = 256
+    num_heads = 4
+    T = 25
+    N = 4
+    C = 256
+    pos_emb_module = RelPositionalEncoding(C, dropout_rate=0.0)
+    encoder_layer = MaskedLmConformerEncoderLayer(embed_dim, num_heads)
+    norm = torch.nn.LayerNorm(embed_dim)
+    encoder = MaskedLmConformerEncoder(encoder_layer, num_layers=4,
+                                       norm=norm)
+
+
+    x = torch.randn(N, T, C)
+    x, pos_enc = pos_emb_module(x)
+    x = x.transpose(0, 1)  # (T, N, C)
+    key_padding_mask = (torch.randn(N, T) > 0.0)  # (N, T)
+    y = encoder(x, pos_enc, key_padding_mask=key_padding_mask)
+
+
 def test_transformer():
     return
     num_features = 40
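
Note: the _get_clones change above is needed because the original body
referenced a bare `ModuleList` and called `copy.deepcopy` without importing
`copy`, so building the encoder stack would raise NameError at construction
time; the new test exercises exactly that path. Below is a minimal standalone
sketch of the corrected pattern (a hypothetical example for illustration, not
code from the patch):

    import copy

    import torch


    def _get_clones(module: torch.nn.Module, N: int) -> torch.nn.ModuleList:
        # Deep-copy so each of the N stacked layers gets its own parameters,
        # rather than N references to the same weights.
        return torch.nn.ModuleList([copy.deepcopy(module) for _ in range(N)])


    layers = _get_clones(torch.nn.Linear(256, 256), N=4)
    assert len(layers) == 4
    # The clones are independent modules, not shared storage:
    assert layers[0].weight.data_ptr() != layers[1].weight.data_ptr()

Assuming pytest is available, the new test can be run on its own with
something like `pytest test_conformer.py -k masked_lm_conformer_encoder`
from egs/librispeech/ASR/conformer_lm.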