Fix batch_size position bug in layer_skip
commit e0b8a0cfd0
parent 686e7e8828
@@ -401,11 +401,13 @@ class Zipformer(EncoderInterface):
             # stacks. The layer_skip_dropout_prob is to discourage it from
             # completely ignoring the middle layers, especially early in
             # training,
-            batch_size = x.shape[0]
+            batch_size = x.shape[1]
             skip_x = self.skip_modules[i](outputs[self.skip_layers[i]], x)
-            if self.training:
-                mask = (torch.rand((batch_size, 1, 1), device=x.device) >
-                        float(self.layer_skip_dropout_prob))
+
+            layer_skip_dropout_prob = float(self.layer_skip_dropout_prob)
+            if self.training and layer_skip_dropout_prob > 0:
+                mask = (torch.rand((1, batch_size, 1), device=x.device) >
+                        layer_skip_dropout_prob)
                 x = torch.where(mask, skip_x, x)
             else:
                 x = skip_x
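For context, here is a minimal standalone sketch of the fixed logic. It assumes, as in Zipformer, that x and skip_x are laid out as (seq_len, batch_size, num_channels), so the batch dimension is dim 1 and the per-utterance mask needs shape (1, batch_size, 1) to broadcast over time and channels. The function name apply_layer_skip and the example shapes are illustrative, not part of the commit.

import torch


def apply_layer_skip(x: torch.Tensor,
                     skip_x: torch.Tensor,
                     layer_skip_dropout_prob: float,
                     training: bool) -> torch.Tensor:
    # x, skip_x: (seq_len, batch_size, num_channels).
    if training and layer_skip_dropout_prob > 0:
        batch_size = x.shape[1]  # batch is dim 1, not dim 0
        # One random draw per utterance; shape (1, batch_size, 1) broadcasts
        # over the time and channel dimensions.
        mask = (torch.rand((1, batch_size, 1), device=x.device)
                > layer_skip_dropout_prob)
        # Keep the skip-module output where mask is True; with probability
        # layer_skip_dropout_prob fall back to plain x for that utterance.
        return torch.where(mask, skip_x, x)
    else:
        return skip_x


# Hypothetical usage with (seq_len, batch, channels) tensors:
x = torch.randn(100, 8, 384)
skip_x = torch.randn(100, 8, 384)
y = apply_layer_skip(x, skip_x, layer_skip_dropout_prob=0.5, training=True)

The old mask shape (batch_size, 1, 1) lined up against the time dimension of x rather than the batch dimension, so the mask either misapplied the skip per time step or failed to broadcast when seq_len differed from batch_size; moving the batch size to dim 1 matches the tensor layout.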