from local

This commit is contained in:
dohe0342 2023-01-21 15:32:09 +09:00
parent 44d4f9ef15
commit 23e54526f1
2 changed files with 6 additions and 0 deletions

View File

@ -102,6 +102,12 @@ class Conformer(Transformer):
)
self.encoder = ConformerEncoder(encoder_layer, num_encoder_layers)
self.group_num = group_num
self.group_layer_num = int(self.encoder_layers // self.group_num)
self.alpha = nn.Parameter(torch.rand(self.group_num))
self.sigmoid = nn.Sigmoid()
self.layer_norm = nn.LayerNorm(d_model)
def run_encoder(
self,
x: torch.Tensor,