diff --git a/egs/tedlium2/ASR/conformer_ctc3/.transformer.py.swp b/egs/tedlium2/ASR/conformer_ctc3/.transformer.py.swp index cf35417bc..e392a0ead 100644 Binary files a/egs/tedlium2/ASR/conformer_ctc3/.transformer.py.swp and b/egs/tedlium2/ASR/conformer_ctc3/.transformer.py.swp differ diff --git a/egs/tedlium2/ASR/conformer_ctc3/transformer.py b/egs/tedlium2/ASR/conformer_ctc3/transformer.py index 17878ca90..4ff9062fe 100644 --- a/egs/tedlium2/ASR/conformer_ctc3/transformer.py +++ b/egs/tedlium2/ASR/conformer_ctc3/transformer.py @@ -249,7 +249,10 @@ class Transformer(nn.Module): """ x = self.encoder_output_layer(x) x = x.permute(1, 0, 2) # (S, N, C) -> (N, S, C) - x = nn.functional.log_softmax(x, dim=-1) # (N, S, C) + if log_prob: + x = nn.functional.log_softmax(x, dim=-1) # (N, S, C) + else: + x = nn.functional.softmax(x, dim=-1) return x @torch.jit.export