diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py
index e67a74d70..add243d55 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py
@@ -18,6 +18,7 @@
 import copy
 import math
 import warnings
+import itertools
 from typing import List, Optional, Tuple, Union
 import logging
 import torch
@@ -473,11 +474,10 @@ class ConformerEncoder(nn.Module):

         layers = list(range(num_layers))
         independent_rng.shuffle(layers)
-        # go through the shuffled layers twice, in case, the first time round,
-        # we did not drop out the target number of layers.
-        layers = layers + layers
-        for layer in layers:
-            if independent_rng.random() < get_layerdrop_prob(layer):
+
+        # go through the shuffled layers until we get the required number of samples.
+        for layer in itertools.cycle(layers):
+            if independent_rng.random() < layerdrop_probs[layer]:
                 ans.add(layer)
                 if len(ans) == num_to_drop:
                     break
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
index 544324148..0e5808369 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
@@ -323,7 +323,6 @@ class ScaledAdam(BatchedOptimizer):
         first_state["model_norms"][step % clipping_update_period] = tot_norm

         if step % clipping_update_period == 0:
-            print(f"step = {step}")
             # Print some stats.
             # We don't reach here if step == 0 because we would have returned
             # above.
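
For reference, below is a minimal standalone sketch (not part of the patch) of the sampling scheme the conformer.py hunk switches to: instead of walking the shuffled layer list exactly twice and hoping enough layers get dropped, it cycles through the shuffled indices indefinitely until the target count is reached. The identifiers num_layers, num_to_drop and layerdrop_probs mirror names visible in the diff; the concrete values and the seed here are illustrative only.

import itertools
import random

num_layers = 12        # assumed encoder depth, for illustration
num_to_drop = 3        # target number of layers to drop this batch (assumed)
layerdrop_probs = [0.075] * num_layers  # per-layer drop probabilities (assumed)

independent_rng = random.Random(42)

layers = list(range(num_layers))
independent_rng.shuffle(layers)

ans = set()
# Cycle through the shuffled layers until we have drawn num_to_drop distinct
# layers; duplicates are absorbed by the set. As in the patch, this loop
# assumes num_to_drop > 0, otherwise it would never break.
for layer in itertools.cycle(layers):
    if independent_rng.random() < layerdrop_probs[layer]:
        ans.add(layer)
        if len(ans) == num_to_drop:
            break

print(sorted(ans))  # three distinct layer indices selected for dropping

The old two-pass loop could finish without having dropped num_to_drop layers when the per-layer probabilities are small; itertools.cycle removes that failure mode without changing the sampling distribution of each draw.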