Refactor/simplify ConformerEncoder

This commit is contained in:
Daniel Povey 2022-02-27 13:56:15 +08:00
parent 581786a6d3
commit 63d8d935d4

View File

@@ -15,7 +15,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import math
import warnings
from typing import Optional, Tuple
@@ -264,13 +264,12 @@ class ConformerEncoderLayer(nn.Module):
return src
class ConformerEncoder(nn.TransformerEncoder):
class ConformerEncoder(nn.Module):
r"""ConformerEncoder is a stack of N encoder layers
Args:
encoder_layer: an instance of the ConformerEncoderLayer() class (required).
num_layers: the number of sub-encoder-layers in the encoder (required).
norm: the layer normalization component (optional).
Examples::
>>> encoder_layer = ConformerEncoderLayer(d_model=512, nhead=8)
@@ -281,11 +280,12 @@ class ConformerEncoder(nn.TransformerEncoder):
"""
def __init__(
self, encoder_layer: nn.Module, num_layers: int, norm: nn.Module = None
self, encoder_layer: nn.Module, num_layers: int
) -> None:
super(ConformerEncoder, self).__init__(
encoder_layer=encoder_layer, num_layers=num_layers, norm=norm
)
super(ConformerEncoder, self).__init__()
self.layers = nn.ModuleList([copy.deepcopy(encoder_layer) for i in range(num_layers)])
self.num_layers = num_layers
def forward(
self,
@@ -320,9 +320,6 @@ class ConformerEncoder(nn.TransformerEncoder):
src_key_padding_mask=src_key_padding_mask,
)
if self.norm is not None:
output = self.norm(output)
return output