Changes to frame masking: done at sequence level, with 2 dimension cutoffs

This commit is contained in:
Daniel Povey 2023-04-10 15:30:56 +08:00
parent fb6a1c1464
commit 73099da6be

View File

@ -276,7 +276,7 @@ class Zipformer2(EncoderInterface):
Args:
x: the embeddings (needed for the shape and dtype and device), of shape
(num_frames, batch_size, encoder_dims0)
(1, batch_size, encoder_dims0)
"""
num_encoders = len(self.encoder_dim)
if not self.training:
@ -286,38 +286,34 @@ class Zipformer2(EncoderInterface):
assert self.encoder_dim[0] == _encoder_dims0
feature_mask_dropout_prob = 0.125
# this is a heuristic, the regions over which we masked frames are 4 times the
# largest of the downsampling_factors. Can be any integer >= 1.
downsampling_multiple = 4
# mask1 shape: (1, batch_size, 1)
mask1 = (torch.rand(1, batch_size, 1,
device=x.device) >
feature_mask_dropout_prob).to(x.dtype)
group_size = max(self.downsampling_factor) * downsampling_multiple
# mask2 has additional sequences masked, about twice the number.
mask2 = torch.logical_and(mask1,
(torch.rand(1, batch_size, 1,
device=x.device) >
feature_mask_dropout_prob).to(x.dtype))
num_groups = (num_frames0 + group_size - 1) // group_size
feature_mask_dropout_prob = 0.2
# shape: (num_groups, batch_size, 1)
group_mask = (torch.rand(num_groups, batch_size, 1,
device=x.device) >
feature_mask_dropout_prob).to(x.dtype)
# dim: (1, batch_size, 2)
mask = torch.cat((mask1, mask2), dim=-1)
feature_masks = []
for i in range(num_encoders):
ds = self.downsampling_factor[i]
frames_per_group = (group_size // ds)
frame_mask = (group_mask.unsqueeze(1).expand(num_groups, frames_per_group,
batch_size, 1)
.reshape(num_groups * frames_per_group, batch_size, 1))
num_frames = (num_frames0 + ds - 1) // ds
frame_mask = frame_mask[:num_frames]
channels = self.encoder_dim[i]
feature_mask = torch.ones(num_frames, batch_size, channels,
dtype=x.dtype, device=x.device)
feature_mask = torch.ones(1, batch_size, channels,
dtype=x.dtype, device=x.device)
u1 = self.encoder_unmasked_dim[i]
feature_mask[:, :, u1:] *= frame_mask
u2 = u1 + (channels - u1) // 2
feature_mask[:, :, u1:u2] *= mask[..., 0:1]
feature_mask[:, :, u2:] *= mask[..., 1:2]
feature_masks.append(feature_mask)
return feature_masks