Merge branch 'rework2i' into rework2i_restoredrop

This commit is contained in:
Daniel Povey 2022-03-31 12:17:02 +08:00
commit 9a0c2e7fee

View File

@@ -54,7 +54,6 @@ class Conformer(EncoderInterface):
num_encoder_layers: int = 12,
dropout: float = 0.1,
cnn_module_kernel: int = 31,
aux_layer_period: int = 3
) -> None:
super(Conformer, self).__init__()
@@ -79,8 +78,7 @@ class Conformer(EncoderInterface):
dropout,
cnn_module_kernel,
)
self.encoder = ConformerEncoder(encoder_layer, num_encoder_layers,
aux_layers=list(range(0, num_encoder_layers-1, aux_layer_period)))
self.encoder = ConformerEncoder(encoder_layer, num_encoder_layers)
self.final_dropout = nn.Dropout(p=dropout)
if output_dim == d_model:
@@ -279,16 +277,13 @@ class ConformerEncoder(nn.Module):
>>> out = conformer_encoder(src, pos_emb)
"""
def __init__(self, encoder_layer: nn.Module, num_layers: int,
aux_layers: Sequence[int]) -> None:
def __init__(self, encoder_layer: nn.Module, num_layers: int) -> None:
super().__init__()
self.layers = nn.ModuleList(
[copy.deepcopy(encoder_layer) for i in range(num_layers)]
)
self.aux_layers = set(aux_layers + [num_layers - 1])
assert num_layers - 1 not in aux_layers
self.num_layers = num_layers
num_channels = encoder_layer.d_model
def forward(
self,