Merge branch 'rework2i' into rework2i_restoredrop

This commit is contained in:
Daniel Povey 2022-03-31 12:17:02 +08:00
commit 9a0c2e7fee

View File

@ -54,7 +54,6 @@ class Conformer(EncoderInterface):
num_encoder_layers: int = 12, num_encoder_layers: int = 12,
dropout: float = 0.1, dropout: float = 0.1,
cnn_module_kernel: int = 31, cnn_module_kernel: int = 31,
aux_layer_period: int = 3
) -> None: ) -> None:
super(Conformer, self).__init__() super(Conformer, self).__init__()
@ -79,8 +78,7 @@ class Conformer(EncoderInterface):
dropout, dropout,
cnn_module_kernel, cnn_module_kernel,
) )
self.encoder = ConformerEncoder(encoder_layer, num_encoder_layers, self.encoder = ConformerEncoder(encoder_layer, num_encoder_layers)
aux_layers=list(range(0, num_encoder_layers-1, aux_layer_period)))
self.final_dropout = nn.Dropout(p=dropout) self.final_dropout = nn.Dropout(p=dropout)
if output_dim == d_model: if output_dim == d_model:
@ -279,16 +277,13 @@ class ConformerEncoder(nn.Module):
>>> out = conformer_encoder(src, pos_emb) >>> out = conformer_encoder(src, pos_emb)
""" """
def __init__(self, encoder_layer: nn.Module, num_layers: int, def __init__(self, encoder_layer: nn.Module, num_layers: int) -> None:
aux_layers: Sequence[int]) -> None:
super().__init__() super().__init__()
self.layers = nn.ModuleList( self.layers = nn.ModuleList(
[copy.deepcopy(encoder_layer) for i in range(num_layers)] [copy.deepcopy(encoder_layer) for i in range(num_layers)]
) )
self.aux_layers = set(aux_layers + [num_layers - 1])
assert num_layers - 1 not in aux_layers
self.num_layers = num_layers self.num_layers = num_layers
num_channels = encoder_layer.d_model
def forward( def forward(
self, self,