diff --git a/egs/tedlium2/ASR/conformer_ctc3/.train.py.swp b/egs/tedlium2/ASR/conformer_ctc3/.train.py.swp index bdf9ff5ee..bfa70faa6 100644 Binary files a/egs/tedlium2/ASR/conformer_ctc3/.train.py.swp and b/egs/tedlium2/ASR/conformer_ctc3/.train.py.swp differ diff --git a/egs/tedlium2/ASR/conformer_ctc3/.transformer.py.swp b/egs/tedlium2/ASR/conformer_ctc3/.transformer.py.swp index feb52c4b1..2fb51e3bf 100644 Binary files a/egs/tedlium2/ASR/conformer_ctc3/.transformer.py.swp and b/egs/tedlium2/ASR/conformer_ctc3/.transformer.py.swp differ diff --git a/egs/tedlium2/ASR/conformer_ctc3/transformer.py b/egs/tedlium2/ASR/conformer_ctc3/transformer.py index afc46e311..ed1bb192e 100644 --- a/egs/tedlium2/ASR/conformer_ctc3/transformer.py +++ b/egs/tedlium2/ASR/conformer_ctc3/transformer.py @@ -186,12 +186,15 @@ class Transformer(nn.Module): encoder_memory, memory_key_padding_mask = self.run_encoder( x, supervision, warmup ) + x = self.ctc_output(encoder_memory) + if type(encoder_memory) == tuple: (encoder_memory, layer_outputs) = encoder_memory layer_outputs = [self.ctc_output(x) for x in layer_outputs] - x = self.ctc_output(encoder_memory) - return (x, layer_outputs), encoder_memory, memory_key_padding_mask + return (x, layer_outputs), encoder_memory, memory_key_padding_mask + else: + return x, encoder_memory, memory_key_padding_mask def run_encoder( self,