Introduce feature mask per frame

Daniel Povey 2022-09-29 17:29:44 +08:00
parent 056b9a4f9a
commit 38f89053bd


@@ -215,6 +215,7 @@ class ConformerEncoderLayer(nn.Module):
attn_scores_in: Optional[Tensor] = None,
src_mask: Optional[Tensor] = None,
src_key_padding_mask: Optional[Tensor] = None,
feature_mask: Optional[Tensor] = None,
warmup: float = 1.0,
) -> Tensor:
"""
@@ -227,6 +228,8 @@ class ConformerEncoderLayer(nn.Module):
passed from layer to layer.
src_mask: the mask for the src sequence (optional).
src_key_padding_mask: the mask for the src keys per batch (optional).
feature_mask: a mask of shape (S, N, E) that randomly zeroes out
some of the features on each frame.
warmup: controls selective bypass of layers; if < 1.0, we will
bypass layers more frequently.
@@ -235,6 +238,7 @@ class ConformerEncoderLayer(nn.Module):
pos_emb: (N, 2*S-1, E)
src_mask: (S, S).
src_key_padding_mask: (N, S).
feature_mask: (S, N, E)
S is the source sequence length, N is the batch size, E is the feature number
"""
src_orig = src
@@ -275,6 +279,9 @@ class ConformerEncoderLayer(nn.Module):
if alpha != 1.0:
src = alpha * src + (1 - alpha) * src_orig
if feature_mask is not None:
src = src * feature_mask
return src, attn_scores_out
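
As a quick illustration of what the new multiply does (a minimal sketch, not part of the commit; E = 384 and the index 256 are assumed sizes chosen to match the 256-dim unmasked split introduced below), masking is a plain elementwise product: frames whose mask entries are zero have exactly those feature dimensions zeroed, and every other frame passes through unchanged.

import torch

S, N, E = 10, 2, 384                  # (seq_len, batch, feature) as in the docstring
src = torch.randn(S, N, E)

# A toy mask that zeroes features 256: on frame 3 of batch item 0 only.
feature_mask = torch.ones(S, N, E)
feature_mask[3, 0, 256:] = 0.0

masked = src * feature_mask           # same operation as `src = src * feature_mask`
assert torch.all(masked[3, 0, 256:] == 0)   # masked features are zeroed
assert torch.equal(masked[4], src[4])       # untouched frames are unchanged
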
@@ -344,6 +351,20 @@ class ConformerEncoder(nn.Module):
outputs = []
attn_scores = None
if self.training:
# feature mask.
# On 0.25 of the frames, drop out the extra features (to force a bottleneck).
feature_mask_dropout_prob = 0.25
feature_unmasked_dim = 256 # hardcode dim for now, 1st 256 are non-masked.
feature_mask = torch.ones_like(src) # S, N, E
# is_masked_frame is 0 with probability `feature_mask_dropout_prob`
is_masked_frame = (torch.rand_like(src[...,:1]) > feature_mask_dropout_prob).to(src.dtype)
feature_mask[..., feature_unmasked_dim:] *= is_masked_frame
else:
feature_mask = None
for i, mod in enumerate(self.layers):
output, attn_scores = mod(
output,
@@ -351,6 +372,7 @@ class ConformerEncoder(nn.Module):
attn_scores,
src_mask=mask,
src_key_padding_mask=src_key_padding_mask,
feature_mask=feature_mask,
warmup=warmup,
)
if i in self.aux_layers:
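
Putting the two hunks together, here is a standalone sketch of the mask construction (illustrative only; the tensor sizes are assumptions, while feature_mask_dropout_prob, feature_unmasked_dim, and the ones_like/rand_like logic mirror the training-only branch above):

import torch

S, N, E = 50, 4, 384                      # E is assumed; it only needs to exceed feature_unmasked_dim
feature_mask_dropout_prob = 0.25
feature_unmasked_dim = 256                # first 256 dims are never masked

src = torch.randn(S, N, E)

feature_mask = torch.ones_like(src)       # (S, N, E)
# One Bernoulli draw per (frame, batch) position; 0 with probability feature_mask_dropout_prob.
is_masked_frame = (torch.rand_like(src[..., :1]) > feature_mask_dropout_prob).to(src.dtype)
feature_mask[..., feature_unmasked_dim:] *= is_masked_frame

masked_src = src * feature_mask
# On roughly 25% of frames only the first 256 features survive, i.e. the bottleneck.
dropped = masked_src[..., feature_unmasked_dim:].abs().sum(dim=-1) == 0
print(f"fraction of bottlenecked frames: {dropped.float().mean():.2f}")   # ~0.25

At inference time the else branch leaves feature_mask as None, so the layers skip the multiply and no features are dropped.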