diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
index 7d4404e9d..d638631dc 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
@@ -401,11 +401,13 @@ class Zipformer(EncoderInterface):
             # stacks. The layer_skip_dropout_prob is to discourage it from
             # completely ignoring the middle layers, especially early in
             # training,
-            batch_size = x.shape[0]
+            batch_size = x.shape[1]
             skip_x = self.skip_modules[i](outputs[self.skip_layers[i]], x)
-            if self.training:
-                mask = (torch.rand((batch_size, 1, 1), device=x.device) >
-                        float(self.layer_skip_dropout_prob))
+
+            layer_skip_dropout_prob = float(self.layer_skip_dropout_prob)
+            if self.training and layer_skip_dropout_prob > 0:
+                mask = (torch.rand((1, batch_size, 1), device=x.device) >
+                        layer_skip_dropout_prob)
                 x = torch.where(mask, skip_x, x)
             else:
                 x = skip_x