diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py
index 2c74a23a6..52f691ac7 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py
@@ -342,7 +342,7 @@ class ConformerEncoder(nn.Module):
         # This ensures that if we are using multiple worker processes, they all use the same
         # random numbers, so they will all take about the same amount of time to process
         # the batch.
-        r = random.Random(self.count)
+        rng = random.Random(self.count)
         self.count += 1
 
         def get_random_mask():
@@ -350,7 +350,7 @@ class ConformerEncoder(nn.Module):
             mask = torch.ones(num_layers, device='cpu')
             if self.training:
                 return mask
-            r = r.random()
+            r = rng.random()
             if r < 0.1:
                 # drop zero layers, to make sure that sometimes we see the complete network.
                 return mask
@@ -358,15 +358,16 @@ class ConformerEncoder(nn.Module):
             if r < 0.1 + 0.25:
                 # with prob 0.25: completely drop the last n layers. let n
                 # be a multiple of 3 (this is what we used to do with aux_layers).
-                final_layers_dropped = 3 * r.randint(1, num_layers // 3)
+                final_layers_dropped = 3 * rng.randint(1, num_layers // 3)
                 mask[-final_layers_dropped:] = 0.0
 
             layer_drop_prob = 0.075
             for i in range(final_layers_dropped):
-                mask[i] = (r.random() > layer_drop_prob)
-
+                mask[i] = (rng.random() > layer_drop_prob)
             if mask.sum() == 0.0:
                 mask[0] = 1.0
+            return mask
+
         mask = get_random_mask()
         device = self.to_layerdrop_scales[0].weight.device
         layerdrop_scales = 1.0 + self.to_layerdrop_scales(mask.to(device))
@@ -1257,6 +1258,13 @@ def _test_conformer_main():
         warmup=0.5,
     )
     f # to remove flake8 warnings
+    c.eval()
+    f = c(
+        torch.randn(batch_size, seq_len, feature_dim),
+        torch.full((batch_size,), seq_len, dtype=torch.int64),
+        warmup=0.5,
+    )
+    f # to remove flake8 warnings