diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/.conformer.py.swp b/egs/librispeech/ASR/pruned_transducer_stateless5/.conformer.py.swp deleted file mode 100644 index 6b909d661..000000000 Binary files a/egs/librispeech/ASR/pruned_transducer_stateless5/.conformer.py.swp and /dev/null differ diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless5/conformer.py index 9b4b2a1ef..8d766dd37 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless5/conformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/conformer.py @@ -188,7 +188,7 @@ class Conformer(EncoderInterface): warmup=warmup, ) # (T, N, C) else: - x = self.encoder( + x, layer_output = self.encoder( x, pos_emb, src_key_padding_mask=src_key_padding_mask, @@ -197,6 +197,8 @@ class Conformer(EncoderInterface): x = x.permute(1, 0, 2) # (T, N, C) ->(N, T, C) + layer_output = [x.permute(1, 0, 2) for x in layer_output] + return x, lengths @torch.jit.export @@ -693,6 +695,8 @@ class ConformerEncoder(nn.Module): outputs = [] + layer_output = [] + for i, mod in enumerate(self.layers): output = mod( output, @@ -704,6 +708,8 @@ class ConformerEncoder(nn.Module): if i in self.aux_layers: outputs.append(output) + layer_output.append(output) + output = self.combiner(outputs) return output, layer_output