Refactor/simplify ConformerEncoder

This commit is contained in:
Daniel Povey 2022-02-27 13:56:15 +08:00
parent 581786a6d3
commit 63d8d935d4

View File

@ -15,7 +15,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import copy
import math import math
import warnings import warnings
from typing import Optional, Tuple from typing import Optional, Tuple
@ -264,13 +264,12 @@ class ConformerEncoderLayer(nn.Module):
return src return src
class ConformerEncoder(nn.TransformerEncoder): class ConformerEncoder(nn.Module):
r"""ConformerEncoder is a stack of N encoder layers r"""ConformerEncoder is a stack of N encoder layers
Args: Args:
encoder_layer: an instance of the ConformerEncoderLayer() class (required). encoder_layer: an instance of the ConformerEncoderLayer() class (required).
num_layers: the number of sub-encoder-layers in the encoder (required). num_layers: the number of sub-encoder-layers in the encoder (required).
norm: the layer normalization component (optional).
Examples:: Examples::
>>> encoder_layer = ConformerEncoderLayer(d_model=512, nhead=8) >>> encoder_layer = ConformerEncoderLayer(d_model=512, nhead=8)
@ -281,11 +280,12 @@ class ConformerEncoder(nn.TransformerEncoder):
""" """
def __init__( def __init__(
self, encoder_layer: nn.Module, num_layers: int, norm: nn.Module = None self, encoder_layer: nn.Module, num_layers: int
) -> None: ) -> None:
super(ConformerEncoder, self).__init__( super(ConformerEncoder, self).__init__()
encoder_layer=encoder_layer, num_layers=num_layers, norm=norm self.layers = nn.ModuleList([copy.deepcopy(encoder_layer) for i in range(num_layers)])
) self.num_layers = num_layers
def forward( def forward(
self, self,
@ -320,9 +320,6 @@ class ConformerEncoder(nn.TransformerEncoder):
src_key_padding_mask=src_key_padding_mask, src_key_padding_mask=src_key_padding_mask,
) )
if self.norm is not None:
output = self.norm(output)
return output return output