From 1603744469d167d848e074f2ea98c587153205fa Mon Sep 17 00:00:00 2001
From: Fangjun Kuang
Date: Sat, 5 Mar 2022 19:26:06 +0800
Subject: [PATCH] Refactor conformer. (#237)

---
 .../ASR/transducer_stateless/conformer.py    | 19 +++----
 .../transducer_stateless/test_conformer.py   | 51 +++++++++++++++++++
 2 files changed, 59 insertions(+), 11 deletions(-)
 create mode 100755 egs/librispeech/ASR/transducer_stateless/test_conformer.py

diff --git a/egs/librispeech/ASR/transducer_stateless/conformer.py b/egs/librispeech/ASR/transducer_stateless/conformer.py
index 81d7708f9..fc838f75b 100644
--- a/egs/librispeech/ASR/transducer_stateless/conformer.py
+++ b/egs/librispeech/ASR/transducer_stateless/conformer.py
@@ -15,7 +15,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-
+import copy
 import math
 import warnings
 from typing import Optional, Tuple
@@ -264,13 +264,12 @@ class ConformerEncoderLayer(nn.Module):
         return src
 
 
-class ConformerEncoder(nn.TransformerEncoder):
+class ConformerEncoder(nn.Module):
     r"""ConformerEncoder is a stack of N encoder layers
 
     Args:
         encoder_layer: an instance of the ConformerEncoderLayer() class (required).
         num_layers: the number of sub-encoder-layers in the encoder (required).
-        norm: the layer normalization component (optional).
 
     Examples::
         >>> encoder_layer = ConformerEncoderLayer(d_model=512, nhead=8)
@@ -280,12 +279,12 @@ class ConformerEncoder(nn.TransformerEncoder):
         >>> out = conformer_encoder(src, pos_emb)
     """
 
-    def __init__(
-        self, encoder_layer: nn.Module, num_layers: int, norm: nn.Module = None
-    ) -> None:
-        super(ConformerEncoder, self).__init__(
-            encoder_layer=encoder_layer, num_layers=num_layers, norm=norm
+    def __init__(self, encoder_layer: nn.Module, num_layers: int) -> None:
+        super().__init__()
+        self.layers = nn.ModuleList(
+            [copy.deepcopy(encoder_layer) for i in range(num_layers)]
         )
+        self.num_layers = num_layers
 
     def forward(
         self,
@@ -320,9 +319,6 @@ class ConformerEncoder(nn.TransformerEncoder):
                 src_key_padding_mask=src_key_padding_mask,
             )
 
-        if self.norm is not None:
-            output = self.norm(output)
-
         return output
 
 
@@ -643,6 +639,7 @@ class RelPositionMultiheadAttention(nn.Module):
             if _b is not None:
                 _b = _b[_start:_end]
             q = nn.functional.linear(query, _w, _b)
+
             # This is inline in_proj function with in_proj_weight and in_proj_bias
             _b = in_proj_bias
             _start = embed_dim
diff --git a/egs/librispeech/ASR/transducer_stateless/test_conformer.py b/egs/librispeech/ASR/transducer_stateless/test_conformer.py
new file mode 100755
index 000000000..d1350c8ab
--- /dev/null
+++ b/egs/librispeech/ASR/transducer_stateless/test_conformer.py
@@ -0,0 +1,51 @@
+#!/usr/bin/env python3
+# Copyright    2022  Xiaomi Corp.        (authors: Daniel Povey
+#                                                  Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+ +""" +To run this file, do: + + cd icefall/egs/librispeech/ASR + python ./transducer_stateless/test_conformer.py +""" + +import torch +from conformer import Conformer + + +def test_conformer(): + feature_dim = 50 + c = Conformer( + num_features=feature_dim, output_dim=256, d_model=128, nhead=4 + ) + batch_size = 5 + seq_len = 20 + # Just make sure the forward pass runs. + logits, lengths = c( + torch.randn(batch_size, seq_len, feature_dim), + torch.full((batch_size,), seq_len, dtype=torch.int64), + ) + print(logits.shape) + print(lengths.shape) + + +def main(): + test_conformer() + + +if __name__ == "__main__": + main()