commit 1a91afa118 (parent 23e54526f1)

    from local
@@ -64,6 +64,7 @@ class Conformer(Transformer):
         dropout: float = 0.1,
         layer_dropout: float = 0.075,
         cnn_module_kernel: int = 31,
+        group_num: int = 0,
     ) -> None:
         super(Conformer, self).__init__(
             num_features=num_features,
@@ -101,12 +102,13 @@ class Conformer(Transformer):
             cnn_module_kernel,
         )
         self.encoder = ConformerEncoder(encoder_layer, num_encoder_layers)
 
-
-        self.group_layer_num = int(self.encoder_layers // self.group_num)
-        self.alpha = nn.Parameter(torch.rand(self.group_num))
-        self.sigmoid = nn.Sigmoid()
-        self.layer_norm = nn.LayerNorm(d_model)
+        self.group_num = group_num
+        if self.group_num != 0:
+            self.group_layer_num = int(self.encoder_layers // self.group_num)
+            self.alpha = nn.Parameter(torch.rand(self.group_num))
+            self.sigmoid = nn.Sigmoid()
+            self.layer_norm = nn.LayerNorm(d_model)
 
     def run_encoder(
         self,
@@ -146,8 +148,10 @@ class Conformer(Transformer):
         x, layer_outputs = self.encoder(
             x, pos_emb, src_key_padding_mask=mask, warmup=warmup
         ) # (T, N, C)
 
 
-
+        if self.group_num != 0:
+            # x = x.permute(1, 0, 2) # (T, N, C) -> (N, T, C)
+
         # return x, lengths
         return x, mask
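The new group_num argument (default 0, which preserves the previous behavior) splits the num_encoder_layers encoder layers into group_num groups of group_layer_num layers each, with one learnable weight per group (alpha), a Sigmoid to squash those weights, and a final LayerNorm. As committed, the "if self.group_num != 0:" branch in run_encoder holds only a commented-out permute and no executable statement, so the file does not yet parse; the aggregation itself is left unwritten. Below is a minimal sketch of what a sigmoid-gated combination of per-group layer outputs could look like; the helper name group_aggregate and the within-group averaging are assumptions, not part of this commit.

# Hypothetical sketch -- NOT part of this commit. Assumes layer_outputs is
# the list of per-layer encoder outputs returned by ConformerEncoder.
import torch
import torch.nn as nn

def group_aggregate(
    layer_outputs: list,       # per-layer outputs, each of shape (T, N, C)
    group_num: int,            # number of layer groups, > 0
    alpha: torch.Tensor,       # (group_num,) learnable per-group weights
    sigmoid: nn.Sigmoid,
    layer_norm: nn.LayerNorm,  # LayerNorm(d_model), applied over C
) -> torch.Tensor:
    # Mirrors self.group_layer_num = int(self.encoder_layers // self.group_num).
    group_layer_num = len(layer_outputs) // group_num
    gates = sigmoid(alpha)     # squash each group weight into (0, 1)
    combined = torch.zeros_like(layer_outputs[0])
    for g in range(group_num):
        group = layer_outputs[g * group_layer_num : (g + 1) * group_layer_num]
        # Average the outputs within the group, then gate by sigmoid(alpha[g]).
        group_mean = torch.stack(group, dim=0).mean(dim=0)
        combined = combined + gates[g] * group_mean
    return layer_norm(combined)  # (T, N, C)

Inside run_encoder this would fill the empty branch, e.g.
x = group_aggregate(layer_outputs, self.group_num, self.alpha, self.sigmoid, self.layer_norm),
after which "return x, mask" is unchanged.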