from local

This commit is contained in:
dohe0342 2023-02-14 17:49:40 +09:00
parent d90ec45e33
commit e77b0077b3
2 changed files with 4 additions and 1 deletions

View File

@ -249,7 +249,10 @@ class Transformer(nn.Module):
"""
x = self.encoder_output_layer(x)
x = x.permute(1, 0, 2) # (S, N, C) -> (N, S, C)
x = nn.functional.log_softmax(x, dim=-1) # (N, S, C)
if log_prob:
x = nn.functional.log_softmax(x, dim=-1) # (N, S, C)
else:
x = nn.functional.softmax(x, dim=-1)
return x
@torch.jit.export