mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
Sometimes mask more frames.
This commit is contained in:
parent
4e36656cef
commit
bb8cbd7598
@ -296,27 +296,38 @@ class Zipformer2(EncoderInterface):
|
||||
|
||||
num_frames_max = (num_frames0 + max_downsampling_factor - 1)
|
||||
|
||||
feature_mask_dropout_prob = 0.15
|
||||
feature_mask_dropout_prob = 0.125
|
||||
|
||||
# frame_mask_max shape: (num_frames_max, batch_size, 1)
|
||||
frame_mask_max = (torch.rand(num_frames_max, batch_size, 1,
|
||||
# frame_mask_max1 shape: (num_frames_max, batch_size, 1)
|
||||
frame_mask_max1 = (torch.rand(num_frames_max, batch_size, 1,
|
||||
device=x.device) >
|
||||
feature_mask_dropout_prob).to(x.dtype)
|
||||
|
||||
# frame_mask_max2 has additional frames masked, about twice the number.
|
||||
frame_mask_max2 = torch.logical_or(frame_mask_max1,
|
||||
(torch.rand(num_frames_max, batch_size, 1,
|
||||
device=x.device) >
|
||||
feature_mask_dropout_prob).to(x.dtype))
|
||||
|
||||
# dim: (num_frames_max, batch_size, 2)
|
||||
frame_mask_max = torch.cat((frame_mask_max1, frame_mask_max2), dim=-1)
|
||||
|
||||
feature_masks = []
|
||||
for i in range(num_encoders):
|
||||
ds = self.downsampling_factor[i]
|
||||
upsample_factor = (max_downsampling_factor // ds)
|
||||
|
||||
frame_mask = (frame_mask_max.unsqueeze(1).expand(num_frames_max, upsample_factor,
|
||||
batch_size, 1)
|
||||
.reshape(num_frames_max * upsample_factor, batch_size, 1))
|
||||
batch_size, 2)
|
||||
.reshape(num_frames_max * upsample_factor, batch_size, 2))
|
||||
num_frames = (num_frames0 + ds - 1) // ds
|
||||
frame_mask = frame_mask[:num_frames]
|
||||
feature_mask = torch.ones(num_frames, batch_size, self.encoder_dim[i],
|
||||
dtype=x.dtype, device=x.device)
|
||||
u = self.encoder_unmasked_dim[i]
|
||||
feature_mask[:, :, u:] *= frame_mask
|
||||
u1 = self.encoder_unmasked_dim[i]
|
||||
u2 = (u1 + self.encoder_dim[i]) // 2
|
||||
feature_mask[:, :, u1:u2] *= frame_mask[..., 0:1]
|
||||
feature_mask[:, :, u2:] *= frame_mask[..., 1:2]
|
||||
feature_masks.append(feature_mask)
|
||||
|
||||
return feature_masks
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user