diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py index f47183e2b..dde0940c5 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py @@ -343,7 +343,20 @@ class ConformerEncoder(nn.Module): output = output * feature_mask - for i, mod in enumerate(self.layers): + + num_layers = len(self.layers) + indexes = list(range(num_layers)) + if self.training: + num_to_swap = int((num_layers - 1) * 0.1) + for i in range(num_to_swap): + j = random.randint(0, num_layers - 2) + a,b = indexes[j], indexes[j+1] + indexes[j] = b + indexes[j+1] = a + + + for i,j in enumerate(indexes): + mod = self.layers[j] output, attn_scores = mod( output, pos_emb,