From 63d8d935d43b719a74bdaa5db3892e71a2b9fe69 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sun, 27 Feb 2022 13:56:15 +0800 Subject: [PATCH] Refactor/simplify ConformerEncoder --- .../ASR/transducer_stateless/conformer.py | 17 +++++++---------- 1 file changed, 7 insertions(+), 10 deletions(-) diff --git a/egs/librispeech/ASR/transducer_stateless/conformer.py b/egs/librispeech/ASR/transducer_stateless/conformer.py index 4627dd147..07b80076d 100644 --- a/egs/librispeech/ASR/transducer_stateless/conformer.py +++ b/egs/librispeech/ASR/transducer_stateless/conformer.py @@ -15,7 +15,7 @@ # See the License for the specific language governing permissions and # limitations under the License. - +import copy import math import warnings from typing import Optional, Tuple @@ -264,13 +264,12 @@ class ConformerEncoderLayer(nn.Module): return src -class ConformerEncoder(nn.TransformerEncoder): +class ConformerEncoder(nn.Module): r"""ConformerEncoder is a stack of N encoder layers Args: encoder_layer: an instance of the ConformerEncoderLayer() class (required). num_layers: the number of sub-encoder-layers in the encoder (required). - norm: the layer normalization component (optional). Examples:: >>> encoder_layer = ConformerEncoderLayer(d_model=512, nhead=8) @@ -281,11 +280,12 @@ class ConformerEncoder(nn.TransformerEncoder): """ def __init__( - self, encoder_layer: nn.Module, num_layers: int, norm: nn.Module = None + self, encoder_layer: nn.Module, num_layers: int ) -> None: - super(ConformerEncoder, self).__init__( - encoder_layer=encoder_layer, num_layers=num_layers, norm=norm - ) + super(ConformerEncoder, self).__init__() + self.layers = nn.ModuleList([copy.deepcopy(encoder_layer) for i in range(num_layers)]) + self.num_layers = num_layers + def forward( self, @@ -320,9 +320,6 @@ class ConformerEncoder(nn.TransformerEncoder): src_key_padding_mask=src_key_padding_mask, ) - if self.norm is not None: - output = self.norm(output) - return output