mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
Changes to frame masking: done at sequence level, with 2 dimension cutoffs
This commit is contained in:
parent
fb6a1c1464
commit
73099da6be
@ -276,7 +276,7 @@ class Zipformer2(EncoderInterface):
|
||||
|
||||
Args:
|
||||
x: the embeddings (needed for the shape and dtype and device), of shape
|
||||
(num_frames, batch_size, encoder_dims0)
|
||||
(1, batch_size, encoder_dims0)
|
||||
"""
|
||||
num_encoders = len(self.encoder_dim)
|
||||
if not self.training:
|
||||
@ -286,38 +286,34 @@ class Zipformer2(EncoderInterface):
|
||||
|
||||
assert self.encoder_dim[0] == _encoder_dims0
|
||||
|
||||
feature_mask_dropout_prob = 0.125
|
||||
|
||||
# this is a heuristic, the regions over which we masked frames are 4 times the
|
||||
# largest of the downsampling_factors. Can be any integer >= 1.
|
||||
downsampling_multiple = 4
|
||||
# mask1 shape: (1, batch_size, 1)
|
||||
mask1 = (torch.rand(1, batch_size, 1,
|
||||
device=x.device) >
|
||||
feature_mask_dropout_prob).to(x.dtype)
|
||||
|
||||
group_size = max(self.downsampling_factor) * downsampling_multiple
|
||||
# mask2 has additional sequences masked, about twice the number.
|
||||
mask2 = torch.logical_and(mask1,
|
||||
(torch.rand(1, batch_size, 1,
|
||||
device=x.device) >
|
||||
feature_mask_dropout_prob).to(x.dtype))
|
||||
|
||||
num_groups = (num_frames0 + group_size - 1) // group_size
|
||||
|
||||
feature_mask_dropout_prob = 0.2
|
||||
|
||||
# shape: (num_groups, batch_size, 1)
|
||||
group_mask = (torch.rand(num_groups, batch_size, 1,
|
||||
device=x.device) >
|
||||
feature_mask_dropout_prob).to(x.dtype)
|
||||
|
||||
# dim: (1, batch_size, 2)
|
||||
mask = torch.cat((mask1, mask2), dim=-1)
|
||||
|
||||
feature_masks = []
|
||||
for i in range(num_encoders):
|
||||
ds = self.downsampling_factor[i]
|
||||
frames_per_group = (group_size // ds)
|
||||
|
||||
frame_mask = (group_mask.unsqueeze(1).expand(num_groups, frames_per_group,
|
||||
batch_size, 1)
|
||||
.reshape(num_groups * frames_per_group, batch_size, 1))
|
||||
num_frames = (num_frames0 + ds - 1) // ds
|
||||
frame_mask = frame_mask[:num_frames]
|
||||
channels = self.encoder_dim[i]
|
||||
feature_mask = torch.ones(num_frames, batch_size, channels,
|
||||
dtype=x.dtype, device=x.device)
|
||||
feature_mask = torch.ones(1, batch_size, channels,
|
||||
dtype=x.dtype, device=x.device)
|
||||
u1 = self.encoder_unmasked_dim[i]
|
||||
feature_mask[:, :, u1:] *= frame_mask
|
||||
u2 = u1 + (channels - u1) // 2
|
||||
|
||||
feature_mask[:, :, u1:u2] *= mask[..., 0:1]
|
||||
feature_mask[:, :, u2:] *= mask[..., 1:2]
|
||||
|
||||
feature_masks.append(feature_mask)
|
||||
|
||||
return feature_masks
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user