from local

This commit is contained in:
dohe0342 2023-02-02 14:16:38 +09:00
parent b67c078fb4
commit 685aaf5c7b
4 changed files with 2 additions and 2 deletions

View File

@@ -224,7 +224,7 @@ class Transformer(nn.Module):
         x = x.permute(1, 0, 2)  # (N, T, C) -> (T, N, C)
         mask = encoder_padding_mask(x.size(0), supervisions)
         mask = mask.to(x.device) if mask is not None else None
-        x, layer_outputs= self.encoder(x, src_key_padding_mask=mask)  # (T, N, C)
+        x, layer_outputs = self.encoder(x, src_key_padding_mask=mask)  # (T, N, C)
         return x, mask

View File

@@ -205,7 +205,7 @@ class Conformer(EncoderInterface):
         x = 0
         for enum, alpha in enumerate(self.alpha):
-            x += self.sigmoid(alpha*layer_outputs[(enum+1)*self.group_layer_num-1])
+            x += self.sigmoid(alpha)*layer_outputs[(enum+1)*self.group_layer_num-1]
         x = self.layer_norm(x/self.group_num)