diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py index d8b1261a6..2d3c9a6a5 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py @@ -378,8 +378,8 @@ class ConformerEncoder(nn.Module): indexes[j+1] = a - for i in indexes: - mod = self.layers[i] + for i,j in enumerate(indexes): + mod = self.layers[j] output, attn_scores = mod( output, pos_emb, @@ -391,12 +391,10 @@ class ConformerEncoder(nn.Module): ) if i in self.aux_layers: outputs.append(output) - if i == num_layers - 1: - final_output = output output = self.combiner(outputs) - output = final_output + 0.0 * attn_scores.sum() # just ensure attn_scores is used in backprop + output = output + 0.0 * attn_scores.sum() # just ensure attn_scores is used in backprop return output