Fix batch_size position bug in layer_skip

Daniel Povey 2023-02-16 15:13:06 +08:00
parent 686e7e8828
commit e0b8a0cfd0

@@ -401,11 +401,13 @@ class Zipformer(EncoderInterface):
             # stacks. The layer_skip_dropout_prob is to discourage it from
             # completely ignoring the middle layers, especially early in
             # training,
-            batch_size = x.shape[0]
+            batch_size = x.shape[1]
             skip_x = self.skip_modules[i](outputs[self.skip_layers[i]], x)
-            if self.training:
-                mask = (torch.rand((batch_size, 1, 1), device=x.device) >
-                        float(self.layer_skip_dropout_prob))
+            layer_skip_dropout_prob = float(self.layer_skip_dropout_prob)
+            if self.training and layer_skip_dropout_prob > 0:
+                mask = (torch.rand((1, batch_size, 1), device=x.device) >
+                        layer_skip_dropout_prob)
                 x = torch.where(mask, skip_x, x)
             else:
                 x = skip_x
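
Summary of the fix: Zipformer's encoder tensors are laid out as (seq_len, batch_size, num_channels), so the batch dimension is dim 1, not dim 0. The old code read x.shape[0], which in this layout is the sequence length, and placed it in dim 0 of the mask, so the random layer-skip was applied per time frame rather than per utterance. The fix reads x.shape[1] and uses a (1, batch_size, 1) mask that broadcasts over frames and channels, and it also skips building the mask when layer_skip_dropout_prob is 0. Below is a minimal standalone sketch of the corrected logic under that tensor layout; the helper name apply_layer_skip and the toy shapes are illustrative, not part of the commit.

import torch

def apply_layer_skip(x: torch.Tensor,
                     skip_x: torch.Tensor,
                     layer_skip_dropout_prob: float,
                     training: bool = True) -> torch.Tensor:
    # x, skip_x: (seq_len, batch_size, num_channels)
    batch_size = x.shape[1]  # batch is dim 1 in this layout, not dim 0
    if training and layer_skip_dropout_prob > 0:
        # Shape (1, batch_size, 1) broadcasts over seq_len and channels, so
        # each utterance in the batch independently keeps or drops the skip.
        mask = (torch.rand((1, batch_size, 1), device=x.device) >
                layer_skip_dropout_prob)
        return torch.where(mask, skip_x, x)
    return skip_x

# Usage with toy shapes:
seq_len, batch_size, num_channels = 50, 4, 384
x = torch.randn(seq_len, batch_size, num_channels)
skip_x = torch.randn(seq_len, batch_size, num_channels)
y = apply_layer_skip(x, skip_x, layer_skip_dropout_prob=0.5)
assert y.shape == (seq_len, batch_size, num_channels)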