mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
Fix or vs. and bug
This commit is contained in:
parent
a02199df79
commit
6e058b9ebd
@ -307,17 +307,17 @@ class Zipformer2(EncoderInterface):
|
|||||||
feature_mask_dropout_prob).to(x.dtype)
|
feature_mask_dropout_prob).to(x.dtype)
|
||||||
|
|
||||||
# frame_mask_max2 has additional frames masked, about twice the number.
|
# frame_mask_max2 has additional frames masked, about twice the number.
|
||||||
frame_mask_max2 = torch.logical_or(frame_mask_max1,
|
frame_mask_max2 = torch.logical_and(frame_mask_max1,
|
||||||
(torch.rand(num_frames_max, batch_size, 1,
|
(torch.rand(num_frames_max, batch_size, 1,
|
||||||
device=x.device) >
|
device=x.device) >
|
||||||
feature_mask_dropout_prob).to(x.dtype))
|
feature_mask_dropout_prob).to(x.dtype))
|
||||||
|
|
||||||
|
|
||||||
# frame_mask_max3 has additional frames masked, about 3 times the number.
|
# frame_mask_max3 has additional frames masked, about 3 times the number.
|
||||||
frame_mask_max3 = torch.logical_or(frame_mask_max2,
|
frame_mask_max3 = torch.logical_and(frame_mask_max2,
|
||||||
(torch.rand(num_frames_max, batch_size, 1,
|
(torch.rand(num_frames_max, batch_size, 1,
|
||||||
device=x.device) >
|
device=x.device) >
|
||||||
feature_mask_dropout_prob).to(x.dtype))
|
feature_mask_dropout_prob).to(x.dtype))
|
||||||
|
|
||||||
# dim: (num_frames_max, batch_size, 3)
|
# dim: (num_frames_max, batch_size, 3)
|
||||||
frame_mask_max = torch.cat((frame_mask_max1, frame_mask_max2, frame_mask_max3), dim=-1)
|
frame_mask_max = torch.cat((frame_mask_max1, frame_mask_max2, frame_mask_max3), dim=-1)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user