From e0b8a0cfd01a674429c7bc3097b3c426ac5cdbb8 Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Thu, 16 Feb 2023 15:13:06 +0800
Subject: [PATCH] Fix batch_size position bug in layer_skip

---
 .../ASR/pruned_transducer_stateless7/zipformer.py  | 10 ++++++----
 1 file changed, 6 insertions(+), 4 deletions(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
index 7d4404e9d..d638631dc 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
@@ -401,11 +401,13 @@ class Zipformer(EncoderInterface):
             # stacks. The layer_skip_dropout_prob is to discourage it from
             # completely ignoring the middle layers, especially early in
             # training,
-            batch_size = x.shape[0]
+            batch_size = x.shape[1]
             skip_x = self.skip_modules[i](outputs[self.skip_layers[i]], x)
-            if self.training:
-                mask = (torch.rand((batch_size, 1, 1), device=x.device) >
-                        float(self.layer_skip_dropout_prob))
+
+            layer_skip_dropout_prob = float(self.layer_skip_dropout_prob)
+            if self.training and layer_skip_dropout_prob > 0:
+                mask = (torch.rand((1, batch_size, 1), device=x.device) >
+                        layer_skip_dropout_prob)
                 x = torch.where(mask, skip_x, x)
             else:
                 x = skip_x
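
Note (editorial sketch, not part of the patch): the fix moves the random skip mask to the second tensor dimension. The layout (seq_len, batch_size, num_channels) is inferred from `batch_size = x.shape[1]` in the patched code; with the old `x.shape[0]`, the mask was effectively sized by the sequence length, so the skip decision was randomized per frame rather than once per utterance. The snippet below is a self-contained illustration of the corrected broadcasting; the shapes and variable names other than `layer_skip_dropout_prob` are invented for the example and do not come from zipformer.py.

    import torch

    # Illustrative layout: (seq_len, batch_size, num_channels), as implied by
    # `batch_size = x.shape[1]` in the patch.
    seq_len, batch_size, num_channels = 5, 3, 8
    layer_skip_dropout_prob = 0.1

    x = torch.randn(seq_len, batch_size, num_channels)       # current encoder output
    skip_x = torch.randn(seq_len, batch_size, num_channels)  # stand-in for the skip combination

    # One keep/drop decision per batch element, broadcast over time and channels.
    mask = (torch.rand((1, batch_size, 1), device=x.device) >
            layer_skip_dropout_prob)
    out = torch.where(mask, skip_x, x)

    # Each utterance is either entirely skip_x or entirely x; there is no
    # per-frame mixing, which is what the (1, batch_size, 1) shape guarantees.
    for b in range(batch_size):
        expected = skip_x[:, b] if mask[0, b, 0] else x[:, b]
        assert torch.equal(out[:, b], expected)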