from local

This commit is contained in:
dohe0342 2023-02-25 15:51:57 +09:00
parent 3d87d438fc
commit ef9c392c5b
2 changed files with 5 additions and 2 deletions

View File

@ -322,8 +322,11 @@ class Transformer(nn.Module):
pred_pad = self.decoder_output_layer(pred_pad) # (N, T, C)
decoder_loss = self.decoder_criterion(pred_pad, ys_out_pad)
return decoder_loss
if return_output:
return pred_pad, decoder_loss
else:
return decoder_loss
@torch.jit.export
def decoder_nll(