from local

This commit is contained in:
dohe0342 2023-01-21 15:33:21 +09:00
parent 23e54526f1
commit 1a91afa118
2 changed files with 10 additions and 6 deletions

View File

@ -64,6 +64,7 @@ class Conformer(Transformer):
dropout: float = 0.1,
layer_dropout: float = 0.075,
cnn_module_kernel: int = 31,
group_num: int = 0,
) -> None:
super(Conformer, self).__init__(
num_features=num_features,
@ -101,12 +102,13 @@ class Conformer(Transformer):
cnn_module_kernel,
)
self.encoder = ConformerEncoder(encoder_layer, num_encoder_layers)
self.group_num = group_num
self.group_layer_num = int(self.encoder_layers // self.group_num)
self.alpha = nn.Parameter(torch.rand(self.group_num))
self.sigmoid = nn.Sigmoid()
self.layer_norm = nn.LayerNorm(d_model)
if self.group_num != 0:
self.group_layer_num = int(self.encoder_layers // self.group_num)
self.alpha = nn.Parameter(torch.rand(self.group_num))
self.sigmoid = nn.Sigmoid()
self.layer_norm = nn.LayerNorm(d_model)
def run_encoder(
self,
@ -146,8 +148,10 @@ class Conformer(Transformer):
x, layer_outputs = self.encoder(
x, pos_emb, src_key_padding_mask=mask, warmup=warmup
) # (T, N, C)
if self.group_num != 0:
# x = x.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
# return x, lengths
return x, mask