diff --git a/egs/librispeech/ASR/conformer_ctc2/.transformer.py.swp b/egs/librispeech/ASR/conformer_ctc2/.transformer.py.swp index 9e8ab3ca9..9311babfd 100644 Binary files a/egs/librispeech/ASR/conformer_ctc2/.transformer.py.swp and b/egs/librispeech/ASR/conformer_ctc2/.transformer.py.swp differ diff --git a/egs/librispeech/ASR/conformer_ctc2/transformer.py b/egs/librispeech/ASR/conformer_ctc2/transformer.py index 737bd9e05..d172fadf8 100644 --- a/egs/librispeech/ASR/conformer_ctc2/transformer.py +++ b/egs/librispeech/ASR/conformer_ctc2/transformer.py @@ -322,8 +322,11 @@ class Transformer(nn.Module): pred_pad = self.decoder_output_layer(pred_pad) # (N, T, C) decoder_loss = self.decoder_criterion(pred_pad, ys_out_pad) - - return decoder_loss + + if return_output: + return pred_pad, decoder_loss + else: + return decoder_loss @torch.jit.export def decoder_nll(