mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
from local
This commit is contained in:
parent
3d87d438fc
commit
ef9c392c5b
Binary file not shown.
@ -322,8 +322,11 @@ class Transformer(nn.Module):
|
|||||||
pred_pad = self.decoder_output_layer(pred_pad) # (N, T, C)
|
pred_pad = self.decoder_output_layer(pred_pad) # (N, T, C)
|
||||||
|
|
||||||
decoder_loss = self.decoder_criterion(pred_pad, ys_out_pad)
|
decoder_loss = self.decoder_criterion(pred_pad, ys_out_pad)
|
||||||
|
|
||||||
return decoder_loss
|
if return_output:
|
||||||
|
return pred_pad, decoder_loss
|
||||||
|
else:
|
||||||
|
return decoder_loss
|
||||||
|
|
||||||
@torch.jit.export
|
@torch.jit.export
|
||||||
def decoder_nll(
|
def decoder_nll(
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user