mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 18:12:19 +00:00
Merge branch 'rework2i' into rework2i_restoredrop
This commit is contained in:
commit
9a0c2e7fee
@ -54,7 +54,6 @@ class Conformer(EncoderInterface):
|
||||
num_encoder_layers: int = 12,
|
||||
dropout: float = 0.1,
|
||||
cnn_module_kernel: int = 31,
|
||||
aux_layer_period: int = 3
|
||||
) -> None:
|
||||
super(Conformer, self).__init__()
|
||||
|
||||
@ -79,8 +78,7 @@ class Conformer(EncoderInterface):
|
||||
dropout,
|
||||
cnn_module_kernel,
|
||||
)
|
||||
self.encoder = ConformerEncoder(encoder_layer, num_encoder_layers,
|
||||
aux_layers=list(range(0, num_encoder_layers-1, aux_layer_period)))
|
||||
self.encoder = ConformerEncoder(encoder_layer, num_encoder_layers)
|
||||
|
||||
self.final_dropout = nn.Dropout(p=dropout)
|
||||
if output_dim == d_model:
|
||||
@ -279,16 +277,13 @@ class ConformerEncoder(nn.Module):
|
||||
>>> out = conformer_encoder(src, pos_emb)
|
||||
"""
|
||||
|
||||
def __init__(self, encoder_layer: nn.Module, num_layers: int,
|
||||
aux_layers: Sequence[int]) -> None:
|
||||
def __init__(self, encoder_layer: nn.Module, num_layers: int) -> None:
|
||||
super().__init__()
|
||||
self.layers = nn.ModuleList(
|
||||
[copy.deepcopy(encoder_layer) for i in range(num_layers)]
|
||||
)
|
||||
self.aux_layers = set(aux_layers + [num_layers - 1])
|
||||
assert num_layers - 1 not in aux_layers
|
||||
self.num_layers = num_layers
|
||||
num_channels = encoder_layer.d_model
|
||||
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
Loading…
x
Reference in New Issue
Block a user