Various fixes, finish implementing frame masking

This commit is contained in:
Daniel Povey 2022-10-06 20:28:47 +08:00
parent e4c9786e4a
commit a3179c30e7
2 changed files with 95 additions and 33 deletions

View File

@ -41,7 +41,7 @@ class Conformer(EncoderInterface):
Args:
num_features (int): Number of input features
subsampling_factor (int): subsampling factor of encoder (the convolution layers before transformers)
d_model (int): (attention_dim1, attention_dim2, output_dim)
d_model (Tuple[int]): embedding dimension of each of the two encoder stacks
nhead (int): number of heads
dim_feedforward (int): feedforward dimension
num_encoder_layers (int): number of encoder layers
@ -56,7 +56,8 @@ class Conformer(EncoderInterface):
num_features: int,
subsampling_factor: int = 4,
conformer_subsampling_factor: int = 4,
d_model: Tuple[int] = (256, 384, 512),
d_model: Tuple[int] = (384, 384),
encoder_unmasked_dim: int = 256,
nhead: Tuple[int] = (8, 8),
feedforward_dim: Tuple[int] = (1536, 2048),
num_encoder_layers: Tuple[int] = (12, 12),
@ -67,6 +68,13 @@ class Conformer(EncoderInterface):
self.num_features = num_features
self.subsampling_factor = subsampling_factor
self.encoder_unmasked_dim = encoder_unmasked_dim
assert 0 < d_model[0] <= d_model[1]
self.d_model = d_model
self.conformer_subsampling_factor = conformer_subsampling_factor
assert encoder_unmasked_dim <= d_model[0] and encoder_unmasked_dim <= d_model[1]
if subsampling_factor != 4:
raise NotImplementedError("Support only 'subsampling_factor=4'.")
@ -112,6 +120,64 @@ class Conformer(EncoderInterface):
self.out_combiner = SimpleCombiner(d_model[0],
d_model[1])
def get_feature_mask(
self,
x: torch.Tensor) -> Tuple[Union[float, Tensor], Union[float, Tensor]]:
"""
In eval mode, returns (1.0, 1.0); in training mode, returns two randomized feature masks
for the first and second encoders (which may run at different frame rates).
On e.g. 15% of frames, these masks will zero out all encoder dims above
some supplied number, e.g. >256, so in effect on those frames we are using
a smaller encoder dim.
We generate the random masks at this level because we want the two masks to 'agree'
all the way up the encoder stack. This means the first mask will have each mask
value repeated self.conformer_subsampling_factor times.
Args:
x: the embeddings (needed for the shape and dtype and device), of shape
(num_frames, batch_size, d_model0)
"""
if not self.training:
return 1.0, 1.0
d_model0, d_model1 = self.d_model
(num_frames0, batch_size, _d_model0) = x.shape
assert d_model0 == _d_model0
ds = self.conformer_subsampling_factor
num_frames1 = ((num_frames0 + ds - 1) // ds)
# on this proportion of the frames, drop out the extra features above
# self.encoder_unmasked_dim.
feature_mask_dropout_prob = 0.15
# we only apply the random frame masking on 90% of sequences; we leave the remaining 10%
# un-masked so that the model has seen un-masked data.
sequence_mask_dropout_prob = 0.9
# frame_mask1 is 0 with probability `feature_mask_dropout_prob`
# frame_mask1 shape: (num_frames1, batch_size, 1)
frame_mask1 = torch.logical_or(
torch.rand(num_frames1, batch_size, 1, device=x.device) > feature_mask_dropout_prob,
torch.rand(1, batch_size, 1, device=x.device) > sequence_mask_dropout_prob).to(x.dtype)
feature_mask1 = torch.ones(num_frames1, batch_size, self.d_model[1],
dtype=x.dtype, device=x.device)
feature_mask1[:, :, self.encoder_unmasked_dim:] *= frame_mask1
# frame_mask0 shape: (num_frames0, batch_size, 1)
frame_mask0 = frame_mask1.unsqueeze(1).expand(num_frames1, ds, batch_size, 1).reshape(
num_frames1 * ds, batch_size, 1)[:num_frames0]
print("frame_mask0 = ", frame_mask0.squeeze(-1))
feature_mask0 = torch.ones(num_frames0, batch_size, self.d_model[0],
dtype=x.dtype, device=x.device)
feature_mask0[:, :, self.encoder_unmasked_dim:] *= frame_mask0
return feature_mask0, feature_mask1
def forward(
self, x: torch.Tensor, x_lens: torch.Tensor,
@ -140,18 +206,19 @@ class Conformer(EncoderInterface):
assert x.size(0) == lengths.max().item()
mask = make_pad_mask(lengths)
feature_mask0, feature_mask1 = self.get_feature_mask(x)
# x1:
x1 = self.encoder1(
x, src_key_padding_mask=mask,
x, feature_mask=feature_mask0, src_key_padding_mask=mask,
) # (T, N, C) where C == d_model[0]
x2 = self.encoder2(
x1, src_key_padding_mask=mask,
x1, feature_mask=feature_mask1, src_key_padding_mask=mask,
) # (T, N, C) where C == d_model[1]
x = self.out_combiner(x1, x2)
x = x.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
return x, lengths
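Below is a small standalone sketch (not part of this commit; the sizes are made up) of the relationship get_feature_mask() sets up between the two masks: frame_mask1 is drawn at the lower frame rate of the second encoder, and frame_mask0 is obtained by repeating each of its values conformer_subsampling_factor times and truncating, so both encoders mask the same stretches of time.

import torch

ds = 4                                       # conformer_subsampling_factor
num_frames0 = 10                             # frames seen by encoder1
num_frames1 = (num_frames0 + ds - 1) // ds   # frames seen by encoder2
batch_size = 2

frame_mask1 = (torch.rand(num_frames1, batch_size, 1) > 0.15).float()
frame_mask0 = (
    frame_mask1.unsqueeze(1)
    .expand(num_frames1, ds, batch_size, 1)
    .reshape(num_frames1 * ds, batch_size, 1)[:num_frames0]
)
# frame t at the full rate shares its mask value with frame t // ds
assert torch.equal(frame_mask0[::ds], frame_mask1)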
@ -319,6 +386,7 @@ class ConformerEncoder(nn.Module):
def forward(
self,
src: Tensor,
feature_mask: Union[Tensor, float] = 1.0,
mask: Optional[Tensor] = None,
src_key_padding_mask: Optional[Tensor] = None,
) -> Tuple[Tensor, Tensor]:
@ -326,6 +394,8 @@ class ConformerEncoder(nn.Module):
Args:
src: the sequence to the encoder (required).
feature_mask: something that broadcasts with src, that we'll multiply `src`
by at every layer.
mask: the mask for the src sequence (optional).
src_key_padding_mask: the mask for the src keys per batch (optional).
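As a quick illustration of the feature_mask argument described above (an assumed standalone snippet, not taken from the repo): a plain float 1.0 is a no-op, while a (S, N, E) tensor with the upper channels zeroed on some frames broadcasts against src in the same way at every layer.

import torch

src = torch.randn(6, 2, 8)          # (S, N, E)
assert torch.equal(src * 1.0, src)  # eval-mode mask: a no-op

feature_mask = torch.ones(6, 2, 8)
# zero dims >= 4 on a random ~15% of frames
feature_mask[:, :, 4:] *= (torch.rand(6, 2, 1) > 0.15).float()
out = src * feature_mask            # this product is applied at every layer
assert out.shape == src.shape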
@ -344,33 +414,8 @@ class ConformerEncoder(nn.Module):
outputs = []
attn_scores = None
# deal with feature masking.
if not self.training:
feature_mask = 1.0
else:
# feature mask.
# on 0.25 of the frames, drop out the extra features [force a bottleneck.]
feature_mask_dropout_prob = 0.15
feature_unmasked_dim = 256 # hardcode dim for now, 1st 256 are non-masked.
feature_mask = torch.ones_like(src) # S, N, E
# frame_mask is 0 with probability `feature_mask_dropout_prob`
# frame_mask shape: (S, N, 1)
frame_mask = (torch.rand_like(src[...,:1]) > feature_mask_dropout_prob).to(src.dtype)
# for 10% of sequences, make the frame mask always-1, i.e. don't drop out any of
# the frames. This is to make sure the model sometimes "sees" the same types of
# un-perturbed sequences that it will see in test time.
frame_mask = torch.logical_or(frame_mask,
torch.rand_like(src[:,:1,:1]) < 0.1)
feature_mask[..., feature_unmasked_dim:] *= frame_mask
output = output * feature_mask
num_layers = len(self.layers)
indexes = list(range(num_layers))
if self.training:
@ -417,6 +462,7 @@ class DownsampledConformerEncoder(nn.Module):
def forward(self,
src: Tensor,
feature_mask: Union[Tensor, float] = 1.0,
mask: Optional[Tensor] = None,
src_key_padding_mask: Optional[Tensor] = None,
) -> Tuple[Tensor, Tensor]:
@ -424,6 +470,9 @@ class DownsampledConformerEncoder(nn.Module):
Args:
src: the sequence to the encoder (required).
feature_mask: something that broadcasts with src, that we'll multiply `src`
by at every layer. feature_mask is expected to be already downsampled by
self.downsample_factor.
mask: the mask for the src sequence (optional). CAUTION: we need to downsample
this, if we are to support it. Won't work correctly yet.
src_key_padding_mask: the mask for the src keys per batch (optional).
@ -446,7 +495,7 @@ class DownsampledConformerEncoder(nn.Module):
src_key_padding_mask = src_key_padding_mask[::ds]
src = self.encoder(
src, src_key_padding_mask=mask,
src, feature_mask=feature_mask, mask=mask, src_key_padding_mask=src_key_padding_mask,
)
src = self.upsample(src)
# remove any extra frames that are not a multiple of downsample_factor
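A small sanity-check sketch (illustrative only, not from the commit): the strided src_key_padding_mask[::ds] above yields ceil(num_frames0 / ds) frames, which is exactly the num_frames1 that get_feature_mask() computes, so a feature_mask that is "already downsampled" lines up with the downsampled sequence.

import torch

ds = 4
for num_frames0 in range(1, 25):
    seq = torch.zeros(num_frames0, 3)
    num_frames1 = (num_frames0 + ds - 1) // ds   # as in get_feature_mask()
    assert seq[::ds].shape[0] == num_frames1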
@ -540,7 +589,10 @@ class SimpleCombiner(torch.nn.Module):
"""
A very simple way of combining 2 vectors of 2 different dims, via a
learned weighted combination in the shared part of the dim.
Args:
dim1: the dimension of the first input, e.g. 256
dim2: the dimension of the second input, e.g. 384. Require dim2 >= dim1.
The output will have the same dimension as dim2.
"""
def __init__(self,
dim1: int,
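The SimpleCombiner implementation itself is outside this diff; the following is a hypothetical sketch, written only to illustrate the docstring above (a learned weighting in the shared part of the dim, output of size dim2), and is not the actual module from the commit.

import torch
import torch.nn as nn

class CombinerSketch(nn.Module):
    """Hypothetical stand-in for SimpleCombiner, matching its docstring only."""
    def __init__(self, dim1: int, dim2: int):
        super().__init__()
        assert dim2 >= dim1
        self.dim1 = dim1
        # one learned mixing weight per shared dimension
        self.weight = nn.Parameter(torch.zeros(dim1))

    def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
        # x1: (..., dim1), x2: (..., dim2) -> (..., dim2)
        w = self.weight.sigmoid()                 # keep weights in (0, 1)
        out = x2.clone()
        out[..., :self.dim1] = w * x1 + (1.0 - w) * x2[..., :self.dim1]
        return out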
@ -1381,7 +1433,7 @@ def _test_conformer_main():
# Just make sure the forward pass runs.
c = Conformer(
num_features=feature_dim, d_model=(64,96,128), nhead=(4,4)
num_features=feature_dim, d_model=(64,96), encoder_unmasked_dim=64, nhead=(4,4)
)
batch_size = 5
seq_len = 20

View File

@ -118,6 +118,15 @@ def add_model_arguments(parser: argparse.ArgumentParser):
"and the output dim of the encoder",
)
parser.add_argument(
"--encoder-unmasked-dim",
type=int,
default=256,
help="Unmasked dimension in the encoder, relates to augmentation during training. "
"Must be <= each of encoder_dims. Empirically, less than 256 seems to make performance "
" worse."
)
parser.add_argument(
"--conformer-subsampling-factor",
type=int,
@ -416,6 +425,7 @@ def get_encoder_model(params: AttributeDict) -> nn.Module:
subsampling_factor=params.subsampling_factor,
conformer_subsampling_factor=params.conformer_subsampling_factor,
d_model=to_int_list(params.encoder_dims),
encoder_unmasked_dim=params.encoder_unmasked_dim,
nhead=to_int_list(params.nhead),
feedforward_dim=to_int_list(params.feedforward_dims),
num_encoder_layers=to_int_list(params.num_encoder_layers),