mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
Have 2 not 3 groups, but give 1st group a smaller dropout prob than the 2nd.
This commit is contained in:
parent
6e058b9ebd
commit
e64ec396bd
@ -297,30 +297,25 @@ class Zipformer2(EncoderInterface):
|
|||||||
num_frames_max = (num_frames0 + max_downsampling_factor - 1)
|
num_frames_max = (num_frames0 + max_downsampling_factor - 1)
|
||||||
|
|
||||||
# we divide the dropped-out feature dimensions into two equal groups;
|
# we divide the dropped-out feature dimensions into two equal groups;
|
||||||
# the first group is dropped out with this probability, the second
|
# the first group is dropped out with probability 0.05, the second
|
||||||
# group is dropped out with about twice this probability.
|
# with probability approximately (0.2 + 0.05).
|
||||||
feature_mask_dropout_prob = 0.1
|
feature_mask_dropout_prob1 = 0.05
|
||||||
|
feature_mask_dropout_prob2 = 0.2
|
||||||
|
|
||||||
# frame_mask_max1 shape: (num_frames_max, batch_size, 1)
|
# frame_mask_max1 shape: (num_frames_max, batch_size, 1)
|
||||||
frame_mask_max1 = (torch.rand(num_frames_max, batch_size, 1,
|
frame_mask_max1 = (torch.rand(num_frames_max, batch_size, 1,
|
||||||
device=x.device) >
|
device=x.device) >
|
||||||
feature_mask_dropout_prob).to(x.dtype)
|
feature_mask_dropout_prob1).to(x.dtype)
|
||||||
|
|
||||||
# frame_mask_max2 has additional frames masked, about twice the number.
|
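# frame_mask_max2 has additional frames masked: a frame stays unmasked only if it passes both random draws, so about 24% of frames are masked (vs. about 5% for frame_mask_max1).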
|
||||||
frame_mask_max2 = torch.logical_and(frame_mask_max1,
|
frame_mask_max2 = torch.logical_and(frame_mask_max1,
|
||||||
(torch.rand(num_frames_max, batch_size, 1,
|
(torch.rand(num_frames_max, batch_size, 1,
|
||||||
device=x.device) >
|
device=x.device) >
|
||||||
feature_mask_dropout_prob).to(x.dtype))
|
feature_mask_dropout_prob2).to(x.dtype))
|
||||||
|
|
||||||
|
|
||||||
# frame_mask_max3 has additional frames masked, about 3 times the number.
|
|
||||||
frame_mask_max3 = torch.logical_and(frame_mask_max2,
|
|
||||||
(torch.rand(num_frames_max, batch_size, 1,
|
|
||||||
device=x.device) >
|
|
||||||
feature_mask_dropout_prob).to(x.dtype))
|
|
||||||
|
|
||||||
# dim: (num_frames_max, batch_size, 3)
|
# dim: (num_frames_max, batch_size, 2)
|
||||||
frame_mask_max = torch.cat((frame_mask_max1, frame_mask_max2, frame_mask_max3), dim=-1)
|
frame_mask_max = torch.cat((frame_mask_max1, frame_mask_max2), dim=-1)
|
||||||
|
|
||||||
feature_masks = []
|
feature_masks = []
|
||||||
for i in range(num_encoders):
|
for i in range(num_encoders):
|
||||||
@ -328,19 +323,17 @@ class Zipformer2(EncoderInterface):
|
|||||||
upsample_factor = (max_downsampling_factor // ds)
|
upsample_factor = (max_downsampling_factor // ds)
|
||||||
|
|
||||||
frame_mask = (frame_mask_max.unsqueeze(1).expand(num_frames_max, upsample_factor,
|
frame_mask = (frame_mask_max.unsqueeze(1).expand(num_frames_max, upsample_factor,
|
||||||
batch_size, 3)
|
batch_size, 2)
|
||||||
.reshape(num_frames_max * upsample_factor, batch_size, 3))
|
.reshape(num_frames_max * upsample_factor, batch_size, 2))
|
||||||
num_frames = (num_frames0 + ds - 1) // ds
|
num_frames = (num_frames0 + ds - 1) // ds
|
||||||
frame_mask = frame_mask[:num_frames]
|
frame_mask = frame_mask[:num_frames]
|
||||||
channels = self.encoder_dim[i]
|
channels = self.encoder_dim[i]
|
||||||
feature_mask = torch.ones(num_frames, batch_size, channels,
|
feature_mask = torch.ones(num_frames, batch_size, channels,
|
||||||
dtype=x.dtype, device=x.device)
|
dtype=x.dtype, device=x.device)
|
||||||
u1 = self.encoder_unmasked_dim[i]
|
u1 = self.encoder_unmasked_dim[i]
|
||||||
u2 = u1 + (channels - u1) // 3
|
u2 = u1 + (channels - u1) // 2
|
||||||
u3 = u1 + 2 * (channels - u1) // 3
|
|
||||||
feature_mask[:, :, u1:u2] *= frame_mask[..., 0:1]
|
feature_mask[:, :, u1:u2] *= frame_mask[..., 0:1]
|
||||||
feature_mask[:, :, u2:u3] *= frame_mask[..., 1:2]
|
feature_mask[:, :, u2:] *= frame_mask[..., 1:2]
|
||||||
feature_mask[:, :, u3:channels] *= frame_mask[..., 2:3]
|
|
||||||
feature_masks.append(feature_mask)
|
feature_masks.append(feature_mask)
|
||||||
|
|
||||||
return feature_masks
|
return feature_masks
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user