Mirror of https://github.com/k2-fsa/icefall.git, synced 2025-12-11 06:55:27 +00:00
Changes to frame masking: done at sequence level, with 2 dimension cutoffs
commit 73099da6be
parent fb6a1c1464
@@ -276,7 +276,7 @@ class Zipformer2(EncoderInterface):

         Args:
           x: the embeddings (needed for the shape and dtype and device), of shape
-            (num_frames, batch_size, encoder_dims0)
+            (1, batch_size, encoder_dims0)
        """
        num_encoders = len(self.encoder_dim)
        if not self.training:
@@ -286,38 +286,34 @@ class Zipformer2(EncoderInterface):
        assert self.encoder_dim[0] == _encoder_dims0

-        # this is a heuristic, the regions over which we masked frames are 4 times the
-        # largest of the downsampling_factors. Can be any integer >= 1.
-        downsampling_multiple = 4
-
-        group_size = max(self.downsampling_factor) * downsampling_multiple
-
-        num_groups = (num_frames0 + group_size - 1) // group_size
-
-        feature_mask_dropout_prob = 0.2
-
-        # shape: (num_groups, batch_size, 1)
-        group_mask = (torch.rand(num_groups, batch_size, 1,
-                                 device=x.device) >
-                      feature_mask_dropout_prob).to(x.dtype)
+        feature_mask_dropout_prob = 0.125
+
+        # mask1 shape: (1, batch_size, 1)
+        mask1 = (torch.rand(1, batch_size, 1,
+                            device=x.device) >
+                 feature_mask_dropout_prob).to(x.dtype)
+
+        # mask2 has additional sequences masked, about twice the number.
+        mask2 = torch.logical_and(mask1,
+                                  (torch.rand(1, batch_size, 1,
+                                              device=x.device) >
+                                   feature_mask_dropout_prob).to(x.dtype))
+
+        # dim: (1, batch_size, 2)
+        mask = torch.cat((mask1, mask2), dim=-1)

        feature_masks = []
        for i in range(num_encoders):
-            ds = self.downsampling_factor[i]
-            frames_per_group = (group_size // ds)
-
-            frame_mask = (group_mask.unsqueeze(1).expand(num_groups, frames_per_group,
-                                                         batch_size, 1)
-                          .reshape(num_groups * frames_per_group, batch_size, 1))
-            num_frames = (num_frames0 + ds - 1) // ds
-            frame_mask = frame_mask[:num_frames]
            channels = self.encoder_dim[i]
-            feature_mask = torch.ones(num_frames, batch_size, channels,
+            feature_mask = torch.ones(1, batch_size, channels,
                                      dtype=x.dtype, device=x.device)
            u1 = self.encoder_unmasked_dim[i]
-            feature_mask[:, :, u1:] *= frame_mask
+            u2 = u1 + (channels - u1) // 2
+
+            feature_mask[:, :, u1:u2] *= mask[..., 0:1]
+            feature_mask[:, :, u2:] *= mask[..., 1:2]

            feature_masks.append(feature_mask)

        return feature_masks
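For illustration, below is a minimal standalone sketch of the masking scheme this commit introduces. The values of encoder_dim, encoder_unmasked_dim and batch_size are made-up stand-ins for the Zipformer2 attributes, and the dtype handling is simplified to plain float32; only the mask construction mirrors the diff.

    import torch

    # Stand-in values; in the commit these come from Zipformer2 attributes.
    encoder_dim = (192, 256)          # channels per encoder stack (assumed)
    encoder_unmasked_dim = (96, 96)   # first u1 channels are never masked (assumed)
    batch_size = 4
    feature_mask_dropout_prob = 0.125

    # mask1: per-sequence keep mask, shape (1, batch_size, 1)
    mask1 = (torch.rand(1, batch_size, 1) > feature_mask_dropout_prob).to(torch.float32)

    # mask2: AND of mask1 with an independent keep mask, so roughly twice as many
    # sequences end up masked (1 - (1 - 0.125)**2 ~= 0.234)
    mask2 = torch.logical_and(
        mask1.bool(),
        torch.rand(1, batch_size, 1) > feature_mask_dropout_prob,
    ).to(torch.float32)

    # shape (1, batch_size, 2): one per-sequence mask for each channel range
    mask = torch.cat((mask1, mask2), dim=-1)

    feature_masks = []
    for i in range(len(encoder_dim)):
        channels = encoder_dim[i]
        feature_mask = torch.ones(1, batch_size, channels)
        u1 = encoder_unmasked_dim[i]
        u2 = u1 + (channels - u1) // 2  # second cutoff halves the maskable range
        feature_mask[:, :, u1:u2] *= mask[..., 0:1]  # zeroed w.p. 0.125
        feature_mask[:, :, u2:] *= mask[..., 1:2]    # zeroed w.p. ~0.234
        feature_masks.append(feature_mask)

    print([fm.shape for fm in feature_masks])
    # [torch.Size([1, 4, 192]), torch.Size([1, 4, 256])]

With feature_mask_dropout_prob = 0.125, a sequence loses its upper channel range (u2:) whenever either of the two independent draws fails, i.e. with probability 1 - (1 - 0.125)^2 ≈ 0.234, which is what the "about twice the number" comment refers to. Because the masks have a frame dimension of 1 and broadcast over all frames, a masked sequence is masked uniformly in time: this is the "sequence level" change named in the commit message, replacing the old per-group-of-frames masks, and u1/u2 are the "2 dimension cutoffs".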