Introduce feature mask per frame
Commit 38f89053bd (parent 056b9a4f9a)
@@ -215,6 +215,7 @@ class ConformerEncoderLayer(nn.Module):
         attn_scores_in: Optional[Tensor] = None,
         src_mask: Optional[Tensor] = None,
         src_key_padding_mask: Optional[Tensor] = None,
+        feature_mask: Optional[Tensor] = None,
         warmup: float = 1.0,
     ) -> Tensor:
         """
@@ -227,6 +228,8 @@ class ConformerEncoderLayer(nn.Module):
                 passed from layer to layer.
             src_mask: the mask for the src sequence (optional).
             src_key_padding_mask: the mask for the src keys per batch (optional).
+            feature_mask: a mask of shape (S, N, E) that randomly zeroes out
+                some of the features on each frame.
             warmup: controls selective bypass of layers; if < 1.0, we will
                 bypass layers more frequently.
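For context, the warmup argument documented here drives the layer-bypass blend that appears in a later hunk (src = alpha * src + (1 - alpha) * src_orig). The snippet below is only a sketch of that blending; the mapping from warmup to alpha is hypothetical and for illustration only, since the actual schedule lives elsewhere in conformer.py and is not touched by this diff.

import torch
from torch import Tensor


def blend_with_bypass(src: Tensor, src_orig: Tensor, warmup: float) -> Tensor:
    # Hypothetical schedule: keep the full layer output once warmup >= 1.0,
    # otherwise blend some of the layer input back in (a partial bypass).
    alpha = 1.0 if warmup >= 1.0 else 0.75 * warmup + 0.25
    if alpha != 1.0:
        src = alpha * src + (1 - alpha) * src_orig
    return src


S, N, E = 10, 2, 512
src_orig = torch.randn(S, N, E)   # layer input
src = torch.randn(S, N, E)        # layer output before blending
out = blend_with_bypass(src, src_orig, warmup=0.5)   # alpha = 0.625 here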
@@ -235,6 +238,7 @@ class ConformerEncoderLayer(nn.Module):
             pos_emb: (N, 2*S-1, E)
             src_mask: (S, S).
             src_key_padding_mask: (N, S).
+            feature_mask: (S, N, E)
             S is the source sequence length, N is the batch size, E is the feature number
         """
         src_orig = src
@@ -275,6 +279,9 @@ class ConformerEncoderLayer(nn.Module):
         if alpha != 1.0:
             src = alpha * src + (1 - alpha) * src_orig

+        if feature_mask is not None:
+            src = src * feature_mask
+
         return src, attn_scores_out
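To make the new src = src * feature_mask line concrete: feature_mask has the same (S, N, E) layout as src, with ones for features that are kept and zeros for the dropped features of masked frames, so an elementwise multiply zeroes exactly those entries. A minimal sketch with toy sizes (the real model uses a larger E and keeps the first 256 dims):

import torch

S, N, E = 4, 2, 8        # toy sizes
unmasked_dim = 3         # stand-in for the hardcoded 256 in the diff

src = torch.randn(S, N, E)
feature_mask = torch.ones_like(src)
# Example: drop the "extra" dims on frame 0 for every batch element.
feature_mask[0, :, unmasked_dim:] = 0.0

out = src * feature_mask
assert torch.all(out[0, :, unmasked_dim:] == 0)
assert torch.equal(out[:, :, :unmasked_dim], src[:, :, :unmasked_dim])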
@@ -344,6 +351,20 @@ class ConformerEncoder(nn.Module):
         outputs = []
         attn_scores = None

+
+        if self.training:
+            # feature mask.
+            # On 0.25 of the frames, drop out the extra features (force a bottleneck).
+            feature_mask_dropout_prob = 0.25
+            feature_unmasked_dim = 256  # hardcode dim for now, 1st 256 are non-masked.
+
+            feature_mask = torch.ones_like(src)  # (S, N, E)
+            # is_masked_frame is 0 with probability `feature_mask_dropout_prob`
+            is_masked_frame = (torch.rand_like(src[..., :1]) > feature_mask_dropout_prob).to(src.dtype)
+            feature_mask[..., feature_unmasked_dim:] *= is_masked_frame
+        else:
+            feature_mask = None
+
         for i, mod in enumerate(self.layers):
             output, attn_scores = mod(
                 output,
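The block added above draws one keep/drop decision per frame and broadcasts it over the upper feature dimensions. Below is a self-contained sketch of the same construction outside the class; src here is just a random stand-in for the encoder activations:

import torch

S, N, E = 100, 8, 512
src = torch.randn(S, N, E)        # stand-in for the encoder activations

feature_mask_dropout_prob = 0.25
feature_unmasked_dim = 256        # the first 256 dims are never masked

feature_mask = torch.ones_like(src)                        # (S, N, E)
# (S, N, 1) values that are 0.0 with probability feature_mask_dropout_prob
is_masked_frame = (
    torch.rand_like(src[..., :1]) > feature_mask_dropout_prob
).to(src.dtype)
feature_mask[..., feature_unmasked_dim:] *= is_masked_frame

masked = src * feature_mask
# Roughly a quarter of the frames lose dims >= 256; the lower dims never change.
print((feature_mask[..., -1] == 0).float().mean())   # ~0.25
assert torch.equal(masked[..., :feature_unmasked_dim], src[..., :feature_unmasked_dim])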
@@ -351,6 +372,7 @@ class ConformerEncoder(nn.Module):
                 attn_scores,
                 src_mask=mask,
                 src_key_padding_mask=src_key_padding_mask,
+                feature_mask=feature_mask,
                 warmup=warmup,
             )
             if i in self.aux_layers:
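Since feature_mask is now passed to every layer of the encoder, a masked frame has its upper feature dimensions zeroed again at the output of each layer in the stack, not just once at the input. A simplified sketch with stand-in layers (ToyLayer is hypothetical and not the real ConformerEncoderLayer, whose constructor is omitted here):

import torch
from torch import nn


class ToyLayer(nn.Module):
    """Hypothetical stand-in for ConformerEncoderLayer: any (S, N, E) -> (S, N, E) map."""

    def __init__(self, dim: int):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, src, feature_mask=None):
        src = self.proj(src)
        if feature_mask is not None:
            src = src * feature_mask     # applied at the end of every layer
        return src


S, N, E, unmasked_dim = 10, 2, 8, 4
layers = nn.ModuleList([ToyLayer(E) for _ in range(3)])

src = torch.randn(S, N, E)
feature_mask = torch.ones_like(src)
is_masked_frame = (torch.rand_like(src[..., :1]) > 0.25).to(src.dtype)
feature_mask[..., unmasked_dim:] *= is_masked_frame

output = src
for mod in layers:                        # the same mask is reused by every layer
    output = mod(output, feature_mask=feature_mask)

# On masked frames the upper dims are zero after the whole stack.
masked_frames = is_masked_frame[..., 0] == 0
assert torch.all(output[masked_frames][..., unmasked_dim:] == 0)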