Mirror of https://github.com/k2-fsa/icefall.git, synced 2025-12-11 06:55:27 +00:00

Commit a3179c30e7 (parent e4c9786e4a): Various fixes, finish implementing frame masking

@@ -41,7 +41,7 @@ class Conformer(EncoderInterface):
     Args:
       num_features (int): Number of input features
       subsampling_factor (int): subsampling factor of encoder (the convolution layers before transformers)
-      d_model (int): (attention_dim1, attention_dim2, output_dim)
+      d_model (int): embedding dimension
       nhead (int): number of heads
       dim_feedforward (int): feedforward dimension
       num_encoder_layers (int): number of encoder layers

@@ -56,7 +56,8 @@ class Conformer(EncoderInterface):
         num_features: int,
         subsampling_factor: int = 4,
+        conformer_subsampling_factor: int = 4,
-        d_model: Tuple[int] = (256, 384, 512),
+        d_model: Tuple[int] = (384, 384),
+        encoder_unmasked_dim: int = 256,
         nhead: Tuple[int] = (8, 8),
         feedforward_dim: Tuple[int] = (1536, 2048),
         num_encoder_layers: Tuple[int] = (12, 12),

@@ -67,6 +68,13 @@ class Conformer(EncoderInterface):

         self.num_features = num_features
         self.subsampling_factor = subsampling_factor
+        self.encoder_unmasked_dim = encoder_unmasked_dim
+        assert 0 < d_model[0] <= d_model[1]
         self.d_model = d_model
+        self.conformer_subsampling_factor = conformer_subsampling_factor
+
+        assert encoder_unmasked_dim <= d_model[0] and encoder_unmasked_dim <= d_model[1]
+
         if subsampling_factor != 4:
             raise NotImplementedError("Support only 'subsampling_factor=4'.")

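The two assertions above pin down the configuration space: the per-encoder dims must be positive and non-decreasing, and the unmasked prefix must fit inside both encoders. A minimal sketch with hypothetical values (not part of the commit):

d_model, encoder_unmasked_dim = (384, 384), 256
assert 0 < d_model[0] <= d_model[1]            # encoder dims are positive and non-decreasing
assert encoder_unmasked_dim <= min(d_model)    # the unmasked prefix fits in both encoders
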
@@ -112,6 +120,64 @@ class Conformer(EncoderInterface):
         self.out_combiner = SimpleCombiner(d_model[0],
                                            d_model[1])

+    def get_feature_mask(
+            self,
+            x: torch.Tensor) -> Tuple[Union[float, Tensor], Union[float, Tensor]]:
+        """
+        In eval mode, returns 1.0; in training mode, returns two randomized feature masks
+        for the first and second encoders (which may run at different frame rates).
+        On e.g. 15% of frames, these masks will zero out all encoder dims larger than
+        some supplied number, e.g. >256, so in effect on those frames we are using
+        a smaller encoder dim.
+
+        We generate the random masks at this level because we want the 2 masks to 'agree'
+        all the way up the encoder stack. This will mean that the first mask will have
+        mask values repeated self.conformer_subsampling_factor times.
+
+        Args:
+          x: the embeddings (needed for the shape and dtype and device), of shape
+             (num_frames, batch_size, d_model0)
+        """
+        if not self.training:
+            return 1.0, 1.0
+
+        d_model0, d_model1 = self.d_model
+        (num_frames0, batch_size, _d_model0) = x.shape
+        assert d_model0 == _d_model0
+        ds = self.conformer_subsampling_factor
+        num_frames1 = ((num_frames0 + ds - 1) // ds)
+
+        # on this proportion of the frames, drop out the extra features above
+        # self.encoder_unmasked_dim.
+        feature_mask_dropout_prob = 0.15
+
+        # we only apply the random frame masking on 90% of sequences; we leave the remaining 10%
+        # un-masked so that the model has seen un-masked data.
+        sequence_mask_dropout_prob = 0.9
+
+        # frame_mask1 is 0 with probability `feature_mask_dropout_prob`
+        # frame_mask1 shape: (num_frames1, batch_size, 1)
+        frame_mask1 = torch.logical_or(
+            torch.rand(num_frames1, batch_size, 1, device=x.device) > feature_mask_dropout_prob,
+            torch.rand(1, batch_size, 1, device=x.device) > sequence_mask_dropout_prob).to(x.dtype)
+
+        feature_mask1 = torch.ones(num_frames1, batch_size, self.d_model[1],
+                                   dtype=x.dtype, device=x.device)
+        feature_mask1[:, :, self.encoder_unmasked_dim:] *= frame_mask1
+
+        # frame_mask0 repeats each frame_mask1 value ds times at the higher frame rate;
+        # frame_mask0 shape: (num_frames0, batch_size, 1)
+        frame_mask0 = frame_mask1.unsqueeze(1).expand(num_frames1, ds, batch_size, 1).reshape(
+            num_frames1 * ds, batch_size, 1)[:num_frames0]
+
+        print("frame_mask0 = ", frame_mask0.squeeze(-1))
+
+        feature_mask0 = torch.ones(num_frames0, batch_size, self.d_model[0],
+                                   dtype=x.dtype, device=x.device)
+        feature_mask0[:, :, self.encoder_unmasked_dim:] *= frame_mask0
+
+        return feature_mask0, feature_mask1
+
+
     def forward(
         self, x: torch.Tensor, x_lens: torch.Tensor,

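To make concrete how the two masks 'agree' across frame rates, here is a minimal standalone sketch (illustrative, not part of the commit) of the expand-and-truncate step that maps the low-frame-rate mask to the high frame rate, with ds = 4:

import torch

ds, num_frames0, batch_size = 4, 10, 1
num_frames1 = (num_frames0 + ds - 1) // ds   # 3 frames at the lower rate

# one 0/1 value per low-rate frame, as for frame_mask1 above
frame_mask1 = (torch.rand(num_frames1, batch_size, 1) > 0.15).float()

# repeat each value ds times, then truncate, as for frame_mask0 above
frame_mask0 = frame_mask1.unsqueeze(1).expand(num_frames1, ds, batch_size, 1).reshape(
    num_frames1 * ds, batch_size, 1)[:num_frames0]

# each block of ds high-rate frames shares its low-rate frame's value
assert torch.equal(frame_mask0[::ds], frame_mask1)
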
@@ -140,18 +206,19 @@ class Conformer(EncoderInterface):
         assert x.size(0) == lengths.max().item()
         mask = make_pad_mask(lengths)

+        feature_mask0, feature_mask1 = self.get_feature_mask(x)
+
-        # x1:
         x1 = self.encoder1(
-            x, src_key_padding_mask=mask,
+            x, feature_mask=feature_mask0, src_key_padding_mask=mask,
         )  # (T, N, C) where C == d_model[0]

         x2 = self.encoder2(
-            x1, src_key_padding_mask=mask,
+            x1, feature_mask=feature_mask1, src_key_padding_mask=mask,
         )  # (T, N, C) where C == d_model[1]

         x = self.out_combiner(x1, x2)

         x = x.permute(1, 0, 2)  # (T, N, C) -> (N, T, C)

         return x, lengths

@@ -319,6 +386,7 @@ class ConformerEncoder(nn.Module):
     def forward(
         self,
         src: Tensor,
+        feature_mask: Union[Tensor, float] = 1.0,
         mask: Optional[Tensor] = None,
         src_key_padding_mask: Optional[Tensor] = None,
     ) -> Tuple[Tensor, Tensor]:

@@ -326,6 +394,8 @@ class ConformerEncoder(nn.Module):

         Args:
             src: the sequence to the encoder (required).
+            feature_mask: something that broadcasts with src, that we'll multiply `src`
+                by at every layer.
             mask: the mask for the src sequence (optional).
             src_key_padding_mask: the mask for the src keys per batch (optional).

@@ -344,33 +414,8 @@ class ConformerEncoder(nn.Module):
         outputs = []
         attn_scores = None

-        # deal with feature masking.
-        if not self.training:
-            feature_mask = 1.0
-        else:
-            # feature mask.
-            # on 0.25 of the frames, drop out the extra features [force a bottleneck.]
-            feature_mask_dropout_prob = 0.15
-            feature_unmasked_dim = 256  # hardcode dim for now, 1st 256 are non-masked.
-
-            feature_mask = torch.ones_like(src)  # S, N, E
-            # frame_mask is 0 with probability `feature_mask_dropout_prob`
-            # frame_mask shape: (S, N, 1)
-            frame_mask = (torch.rand_like(src[..., :1]) > feature_mask_dropout_prob).to(src.dtype)
-
-            # for 10% of sequences, make the frame mask always-1, i.e. don't drop out any of
-            # the frames. This is to make sure the model sometimes "sees" the same types of
-            # un-perturbed sequences that it will see in test time.
-            frame_mask = torch.logical_or(frame_mask,
-                                          torch.rand_like(src[:, :1, :1]) < 0.1)
-
-            feature_mask[..., feature_unmasked_dim:] *= frame_mask
-
         output = output * feature_mask

         num_layers = len(self.layers)
         indexes = list(range(num_layers))
         if self.training:

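With the mask now passed in from above, the per-layer `output = output * feature_mask` covers both modes: a float 1.0 at eval time (a no-op) and a broadcasting tensor during training. A small sketch with assumed shapes (not part of the commit):

import torch

S, N, E = 6, 2, 384   # (seq_len, batch, embed dim), hypothetical
output = torch.randn(S, N, E)

assert torch.equal(output * 1.0, output)   # eval: the float mask is a no-op

feature_mask = torch.ones(S, N, E)
feature_mask[:, :, 256:] *= (torch.rand(S, N, 1) > 0.15).float()
masked = output * feature_mask             # train: dims >= 256 zeroed on ~15% of frames
assert torch.equal(masked[:, :, :256], output[:, :, :256])   # first 256 dims always survive
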
@@ -417,6 +462,7 @@ class DownsampledConformerEncoder(nn.Module):

     def forward(self,
                 src: Tensor,
+                feature_mask: Union[Tensor, float] = 1.0,
                 mask: Optional[Tensor] = None,
                 src_key_padding_mask: Optional[Tensor] = None,
                 ) -> Tuple[Tensor, Tensor]:

@@ -424,6 +470,9 @@ class DownsampledConformerEncoder(nn.Module):

         Args:
             src: the sequence to the encoder (required).
+            feature_mask: something that broadcasts with src, that we'll multiply `src`
+                by at every layer. feature_mask is expected to be already downsampled by
+                self.downsample_factor.
             mask: the mask for the src sequence (optional). CAUTION: we need to downsample
                 this, if we are to support it. Won't work correctly yet.
             src_key_padding_mask: the mask for the src keys per batch (optional).

@@ -446,7 +495,7 @@ class DownsampledConformerEncoder(nn.Module):
             src_key_padding_mask = src_key_padding_mask[::ds]

         src = self.encoder(
-            src, src_key_padding_mask=mask,
+            src, feature_mask=feature_mask, mask=mask, src_key_padding_mask=src_key_padding_mask,
         )
         src = self.upsample(src)
         # remove any extra frames that are not a multiple of downsample_factor

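The frame bookkeeping around this call: with downsample factor ds the encoder runs on ceil(T / ds) frames, upsampling restores ceil(T / ds) * ds frames, and the tail beyond the original T is trimmed. A sketch of that arithmetic with stand-in downsample/upsample ops (hypothetical tensors; the real modules are defined elsewhere in this file):

import torch

T, N, E, ds = 10, 2, 384, 4
src = torch.randn(T, N, E)

T1 = (T + ds - 1) // ds                  # 3 frames at the lower rate
down = src[::ds]                         # stand-in for the downsampling module
assert down.shape[0] == T1

up = down.repeat_interleave(ds, dim=0)   # stand-in for self.upsample: T1 * ds = 12 frames
src_out = up[:T]                         # remove any extra frames
assert src_out.shape[0] == T
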
@@ -540,7 +589,10 @@ class SimpleCombiner(torch.nn.Module):
     """
     A very simple way of combining 2 vectors of 2 different dims, via a
     learned weighted combination in the shared part of the dim.
+
+    Args:
+      dim1: the dimension of the first input, e.g. 256
+      dim2: the dimension of the second input, e.g. 384. Require dim2 >= dim1.
+        The output will have the same dimension as dim2.
     """
     def __init__(self,
                  dim1: int,

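SimpleCombiner's body is outside this hunk; one plausible realization consistent with the docstring, i.e. a learned per-dim weight on the shared dims with output dimension dim2, would be (an illustrative assumption, not the commit's code):

import torch
import torch.nn as nn

class SimpleCombinerSketch(nn.Module):
    # Illustrative only: combine (T, N, dim1) and (T, N, dim2), with dim2 >= dim1.
    def __init__(self, dim1: int, dim2: int):
        super().__init__()
        self.weight = nn.Parameter(torch.zeros(dim1))   # learned mixing weight per shared dim

    def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
        dim1 = x1.shape[-1]
        shared = x1 * self.weight + x2[..., :dim1] * (1.0 - self.weight)
        return torch.cat((shared, x2[..., dim1:]), dim=-1)  # output has dim2

For example, SimpleCombinerSketch(256, 384)(torch.randn(20, 5, 256), torch.randn(20, 5, 384)) has shape (20, 5, 384).
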
@@ -1381,7 +1433,7 @@ def _test_conformer_main():
     # Just make sure the forward pass runs.

     c = Conformer(
-        num_features=feature_dim, d_model=(64,96,128), nhead=(4,4)
+        num_features=feature_dim, d_model=(64,96), encoder_unmasked_dim=64, nhead=(4,4)
     )
     batch_size = 5
     seq_len = 20

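The remainder of the smoke test sits outside this hunk; a plausible continuation under the same parameters (hypothetical: feature_dim's test value is assumed, and the Conformer class above is assumed importable, with the file's usual (batch, time, feature) input convention) is:

import torch

feature_dim = 50   # assumed test value
c = Conformer(num_features=feature_dim, d_model=(64, 96), encoder_unmasked_dim=64, nhead=(4, 4))
batch_size, seq_len = 5, 20
x = torch.randn(batch_size, seq_len, feature_dim)
x_lens = torch.full((batch_size,), seq_len, dtype=torch.int64)
y, y_lens = c(x, x_lens)   # y: (batch_size, num_out_frames, d_model[-1])
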
@@ -118,6 +118,15 @@ def add_model_arguments(parser: argparse.ArgumentParser):
         "and the output dim of the encoder",
     )

+    parser.add_argument(
+        "--encoder-unmasked-dim",
+        type=int,
+        default=256,
+        help="Unmasked dimension in the encoder, relates to augmentation during training. "
+        "Must be <= each of encoder_dims. Empirically, less than 256 seems to make "
+        "performance worse.",
+    )
+
     parser.add_argument(
         "--conformer-subsampling-factor",
         type=int,

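For reference, the new flag lands in params.encoder_unmasked_dim via argparse's usual dash-to-underscore mapping; a self-contained sketch of just that wiring (hypothetical value, replicating only the add_argument call above):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--encoder-unmasked-dim", type=int, default=256)
args = parser.parse_args(["--encoder-unmasked-dim", "192"])
assert args.encoder_unmasked_dim == 192   # dashes become underscores in the dest name
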
@@ -416,6 +425,7 @@ def get_encoder_model(params: AttributeDict) -> nn.Module:
         subsampling_factor=params.subsampling_factor,
+        conformer_subsampling_factor=params.conformer_subsampling_factor,
         d_model=to_int_list(params.encoder_dims),
+        encoder_unmasked_dim=params.encoder_unmasked_dim,
         nhead=to_int_list(params.nhead),
         feedforward_dim=to_int_list(params.feedforward_dims),
         num_encoder_layers=to_int_list(params.num_encoder_layers),