from local

This commit is contained in:
dohe0342 2023-02-02 14:15:21 +09:00
parent cec9ab41fa
commit b67c078fb4
2 changed files with 1 additions and 1 deletions

View File

@ -224,7 +224,7 @@ class Transformer(nn.Module):
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
mask = encoder_padding_mask(x.size(0), supervisions) mask = encoder_padding_mask(x.size(0), supervisions)
mask = mask.to(x.device) if mask is not None else None mask = mask.to(x.device) if mask is not None else None
x = self.encoder(x, src_key_padding_mask=mask) # (T, N, C) x, layer_outputs= self.encoder(x, src_key_padding_mask=mask) # (T, N, C)
return x, mask return x, mask