Divide feature_mask into 3 groups

Daniel Povey 2023-03-29 16:08:03 +08:00
parent b8f0756133
commit f1dbf4222e


@@ -299,7 +299,7 @@ class Zipformer2(EncoderInterface):
         # we divide the dropped-out feature dimensions into two equal groups;
         # the first group is dropped out with this probability, the second
         # group is dropped out with about twice this probability.
-        feature_mask_dropout_prob = 0.125
+        feature_mask_dropout_prob = 0.1

         # frame_mask_max1 shape: (num_frames_max, batch_size, 1)
         frame_mask_max1 = (torch.rand(num_frames_max, batch_size, 1,
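For intuition on the graded dropout rates (an illustration, not part of the commit): if a frame is kept only when it passes one, two, or three independent keep-draws with keep probability 1 - p, the three groups are dropped with probability p, 1 - (1 - p)^2 and 1 - (1 - p)^3, which for the new p = 0.1 is roughly 0.1, 0.19 and 0.27, i.e. about one, two and three times the base rate.

    # Illustration only: approximate drop rate per group for the new base prob.
    p = 0.1  # feature_mask_dropout_prob after this commit
    for k in (1, 2, 3):
        print(f"group {k}: dropped with prob {1 - (1 - p) ** k:.3f} (~{k} * p)")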
@@ -312,8 +312,15 @@ class Zipformer2(EncoderInterface):
                                                         device=x.device) >
                                              feature_mask_dropout_prob).to(x.dtype))
-        # dim: (num_frames_max, batch_size, 2)
-        frame_mask_max = torch.cat((frame_mask_max1, frame_mask_max2), dim=-1)
+        # frame_mask_max3 has additional frames masked, about 3 times the number.
+        frame_mask_max3 = torch.logical_or(frame_mask_max2,
+                                           (torch.rand(num_frames_max, batch_size, 1,
+                                                       device=x.device) >
+                                            feature_mask_dropout_prob).to(x.dtype))
+        # dim: (num_frames_max, batch_size, 3)
+        frame_mask_max = torch.cat((frame_mask_max1, frame_mask_max2, frame_mask_max3), dim=-1)

         feature_masks = []
         for i in range(num_encoders):
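As a rough, standalone sketch of how the stacked frame masks behave (frame_mask_max2 is defined above this hunk, so its exact form is assumed here; the sketch combines each successive random keep-mask with torch.logical_and, which is what makes each group drop progressively more frames, and all sizes are made up):

    import torch

    num_frames_max, batch_size = 1000, 8   # hypothetical sizes
    p = 0.1                                # feature_mask_dropout_prob

    def keep_mask() -> torch.Tensor:
        # 1.0 with probability (1 - p), 0.0 with probability p.
        return (torch.rand(num_frames_max, batch_size, 1) > p).float()

    frame_mask_max1 = keep_mask()
    # Each further mask ANDs in another keep-mask, so group k is dropped
    # with probability roughly k * p.
    frame_mask_max2 = torch.logical_and(frame_mask_max1, keep_mask()).float()
    frame_mask_max3 = torch.logical_and(frame_mask_max2, keep_mask()).float()

    # (num_frames_max, batch_size, 3): one keep/drop column per channel group.
    frame_mask_max = torch.cat((frame_mask_max1, frame_mask_max2, frame_mask_max3), dim=-1)
    print(1.0 - frame_mask_max.mean(dim=(0, 1)))   # roughly [0.10, 0.19, 0.27]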
@@ -321,16 +328,19 @@ class Zipformer2(EncoderInterface):
             upsample_factor = (max_downsampling_factor // ds)

             frame_mask = (frame_mask_max.unsqueeze(1).expand(num_frames_max, upsample_factor,
-                                                             batch_size, 2)
-                          .reshape(num_frames_max * upsample_factor, batch_size, 2))
+                                                             batch_size, 3)
+                          .reshape(num_frames_max * upsample_factor, batch_size, 3))
             num_frames = (num_frames0 + ds - 1) // ds
             frame_mask = frame_mask[:num_frames]
-            feature_mask = torch.ones(num_frames, batch_size, self.encoder_dim[i],
+            channels = self.encoder_dim[i]
+            feature_mask = torch.ones(num_frames, batch_size, channels,
                                       dtype=x.dtype, device=x.device)
             u1 = self.encoder_unmasked_dim[i]
-            u2 = (u1 + self.encoder_dim[i]) // 2
+            u2 = u1 + (channels - u1) // 3
+            u3 = u1 + 2 * (channels - u1) // 3
             feature_mask[:, :, u1:u2] *= frame_mask[..., 0:1]
-            feature_mask[:, :, u2:] *= frame_mask[..., 1:2]
+            feature_mask[:, :, u2:u3] *= frame_mask[..., 1:2]
+            feature_mask[:, :, u3:channels] *= frame_mask[..., 2:3]
             feature_masks.append(feature_mask)

         return feature_masks
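To see how one encoder's feature mask is assembled from those three columns, here is a self-contained sketch with hypothetical dimensions (encoder_dim, encoder_unmasked_dim and frame_mask below are stand-ins for illustration, not the module's real attributes):

    import torch

    num_frames, batch_size = 50, 8
    encoder_dim, encoder_unmasked_dim = 384, 256   # hypothetical per-encoder sizes

    # Stand-in for frame_mask[..., 0:3]: keep/drop flags for the three groups.
    frame_mask = (torch.rand(num_frames, batch_size, 3) > 0.1).float()

    channels = encoder_dim
    feature_mask = torch.ones(num_frames, batch_size, channels)

    # Channels below u1 are never masked; the rest are split into three
    # roughly equal groups, each tied to one frame-mask column.
    u1 = encoder_unmasked_dim
    u2 = u1 + (channels - u1) // 3
    u3 = u1 + 2 * (channels - u1) // 3
    feature_mask[:, :, u1:u2] *= frame_mask[..., 0:1]
    feature_mask[:, :, u2:u3] *= frame_mask[..., 1:2]
    feature_mask[:, :, u3:channels] *= frame_mask[..., 2:3]

    # feature_mask multiplies the encoder activations elementwise, zeroing a
    # whole channel group for the frames chosen for dropout.
    print(feature_mask.shape)   # torch.Size([50, 8, 384])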