mirror of https://github.com/k2-fsa/icefall.git
Fix batch_size position bug in layer_skip
commit e0b8a0cfd0
parent 686e7e8828
@@ -401,11 +401,13 @@ class Zipformer(EncoderInterface):
             # stacks. The layer_skip_dropout_prob is to discourage it from
             # completely ignoring the middle layers, especially early in
             # training,
-            batch_size = x.shape[0]
+            batch_size = x.shape[1]
             skip_x = self.skip_modules[i](outputs[self.skip_layers[i]], x)
-            if self.training:
-                mask = (torch.rand((batch_size, 1, 1), device=x.device) >
-                        float(self.layer_skip_dropout_prob))
+            layer_skip_dropout_prob = float(self.layer_skip_dropout_prob)
+            if self.training and layer_skip_dropout_prob > 0:
+                mask = (torch.rand((1, batch_size, 1), device=x.device) >
+                        layer_skip_dropout_prob)
                 x = torch.where(mask, skip_x, x)
             else:
                 x = skip_x
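For context, here is a minimal standalone sketch of what the fixed branch computes, assuming Zipformer's (seq_len, batch_size, channels) tensor layout (which the change from x.shape[0] to x.shape[1] implies). The tensor sizes, the dropout probability, and the skip_x stand-in below are illustrative, not values from the repo. It also shows why the old code was wrong: with batch_size = x.shape[0], the variable actually held seq_len, so the (batch_size, 1, 1) mask dropped the skip connection independently per frame rather than once per utterance.

    import torch

    # Illustrative sizes; batch is dim 1 in Zipformer's internal layout.
    seq_len, batch_size, channels = 100, 8, 384
    x = torch.randn(seq_len, batch_size, channels)       # encoder activations
    skip_x = torch.randn(seq_len, batch_size, channels)  # stand-in for the skip-module output
    layer_skip_dropout_prob = 0.1                        # hypothetical value

    training = True
    if training and layer_skip_dropout_prob > 0:
        # One Bernoulli draw per utterance: shape (1, batch_size, 1)
        # broadcasts over the time and channel dims. The pre-fix mask,
        # built from x.shape[0], had shape (seq_len, 1, 1) and therefore
        # dropped the skip connection per frame, not per utterance.
        mask = (torch.rand((1, batch_size, 1), device=x.device) >
                layer_skip_dropout_prob)
        x = torch.where(mask, skip_x, x)  # take skip_x where mask is True
    else:
        x = skip_x  # at inference (or prob == 0) always take the skip path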