Refactor zipformer for more flexibility so we can change the number of encoder layers.

Daniel Povey 2022-10-28 17:32:38 +08:00
parent e592a920b4
commit ed1b4d5e5d
2 changed files with 162 additions and 107 deletions


@@ -127,19 +127,19 @@ def add_model_arguments(parser: argparse.ArgumentParser):
     )

     parser.add_argument(
-        "--encoder-unmasked-dim",
-        type=int,
-        default=256,
-        help="Unmasked dimension in the encoder, relates to augmentation during training. "
+        "--encoder-unmasked-dims",
+        type=str,
+        default="256,256",
+        help="Unmasked dimensions in the encoders, relates to augmentation during training. "
         "Must be <= each of encoder_dims. Empirically, less than 256 seems to make performance "
         " worse."
     )

     parser.add_argument(
-        "--zipformer-subsampling-factor",
-        type=int,
-        default=2,
-        help="Subsampling factor for 2nd stack of encoder layers.",
+        "--zipformer-downsampling-factors",
+        type=str,
+        default="1,2",
+        help="Downsampling factor for each stack of encoder layers.",
     )

     parser.add_argument(
@@ -437,10 +437,10 @@ def get_encoder_model(params: AttributeDict) -> nn.Module:
     encoder = Zipformer(
         num_features=params.feature_dim,
         subsampling_factor=params.subsampling_factor,
-        zipformer_subsampling_factor=params.zipformer_subsampling_factor,
-        d_model=to_int_tuple(params.encoder_dims),
+        zipformer_downsampling_factors=to_int_tuple(params.zipformer_downsampling_factors),
+        encoder_dims=to_int_tuple(params.encoder_dims),
         attention_dim=to_int_tuple(params.attention_dims),
-        encoder_unmasked_dim=params.encoder_unmasked_dim,
+        encoder_unmasked_dims=to_int_tuple(params.encoder_unmasked_dims),
         nhead=to_int_tuple(params.nhead),
         feedforward_dim=to_int_tuple(params.feedforward_dims),
         num_encoder_layers=to_int_tuple(params.num_encoder_layers),
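Both new options are comma-separated strings, converted to tuples of ints by to_int_tuple before reaching the Zipformer constructor. The helper itself is not part of this diff; a minimal sketch consistent with how it is used here (an assumption, not the project's verbatim definition):

    def to_int_tuple(s: str) -> tuple:
        # e.g. "256,256" -> (256, 256); "1,2" -> (1, 2)
        return tuple(int(x) for x in s.split(","))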


@@ -64,9 +64,10 @@ class Zipformer(EncoderInterface):
         num_features: int,
         subsampling_factor: int = 4,
-        zipformer_subsampling_factor: int = 4,
-        d_model: Tuple[int] = (384, 384),
+        encoder_dims: Tuple[int] = (384, 384),
         attention_dim: Tuple[int] = (256, 256),
-        encoder_unmasked_dim: int = 256,
+        encoder_unmasked_dims: Tuple[int] = (256, 256),
+        zipformer_downsampling_factors: Tuple[int] = (1, 2),
         nhead: Tuple[int] = (8, 8),
         feedforward_dim: Tuple[int] = (1536, 2048),
         num_encoder_layers: Tuple[int] = (12, 12),
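Since every per-stack hyperparameter is now a tuple indexed by stack, the number of encoder stacks is simply the common length of these tuples. A hedged illustration of a three-stack configuration (the values are made up for illustration; note that cnn_module_kernel, like the other tuples, needs one entry per stack, and the first downsampling factor must be 1 per the assertion added in the constructor below):

    encoder = Zipformer(
        num_features=80,
        encoder_dims=(256, 384, 384),
        attention_dim=(192, 256, 256),
        encoder_unmasked_dims=(192, 256, 256),
        zipformer_downsampling_factors=(1, 2, 4),
        nhead=(8, 8, 8),
        feedforward_dim=(1024, 1536, 2048),
        num_encoder_layers=(8, 12, 12),
        cnn_module_kernel=(31, 31, 31),
    )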
@@ -78,74 +79,69 @@ class Zipformer(EncoderInterface):
         self.num_features = num_features
         self.subsampling_factor = subsampling_factor
-        self.encoder_unmasked_dim = encoder_unmasked_dim
-        assert 0 < d_model[0] <= d_model[1]
-        self.d_model = d_model
-        self.zipformer_subsampling_factor = zipformer_subsampling_factor
+        self.encoder_unmasked_dims = encoder_unmasked_dims
+        assert 0 < encoder_dims[0] <= encoder_dims[1]
+        self.encoder_dims = encoder_dims
+        self.zipformer_downsampling_factors = zipformer_downsampling_factors
-        assert encoder_unmasked_dim <= d_model[0] and encoder_unmasked_dim <= d_model[1]
+        for u, d in zip(encoder_unmasked_dims, encoder_dims):
+            assert u <= d

         if subsampling_factor != 4:
             raise NotImplementedError("Support only 'subsampling_factor=4'.")

         # self.encoder_embed converts the input of shape (N, T, num_features)
-        # to the shape (N, T//subsampling_factor, d_model).
+        # to the shape (N, T//subsampling_factor, encoder_dims).
         # That is, it does two things simultaneously:
         #   (1) subsampling: T -> T//subsampling_factor
-        #   (2) embedding: num_features -> d_model
-        self.encoder_embed = Conv2dSubsampling(num_features, d_model[0],
+        #   (2) embedding: num_features -> encoder_dims
+        self.encoder_embed = Conv2dSubsampling(num_features, encoder_dims[0],
                                                dropout=dropout)

-        encoder_layer1 = ZipformerEncoderLayer(
-            d_model[0],
-            attention_dim[0],
-            nhead[0],
-            feedforward_dim[0],
-            dropout,
-            cnn_module_kernel[0],
-        )
-
-        # for the first third of the warmup period, we let the Conv2dSubsampling
-        # layer learn something.  then start warmup up the first and then the second
-        # encoder.
-        self.encoder1 = ZipformerEncoder(
-            encoder_layer1,
-            num_encoder_layers[0],
-            dropout,
-            warmup_begin=warmup_batches / 3,
-            warmup_end=warmup_batches * 2 / 3,
-        )
-
-        encoder_layer2 = ZipformerEncoderLayer(
-            d_model[1],
-            attention_dim[1],
-            nhead[1],
-            feedforward_dim[1],
-            dropout,
-            cnn_module_kernel[1],
-        )
-
-        self.encoder2 = DownsampledZipformerEncoder(
-            ZipformerEncoder(
-                encoder_layer2,
-                num_encoder_layers[1],
-                dropout,
-                warmup_begin=warmup_batches * 2 / 3,
-                warmup_end=warmup_batches,
-            ),
-            input_dim=d_model[0],
-            output_dim=d_model[1],
-            downsample=zipformer_subsampling_factor,
-        )
-
-        self.out_combiner = SimpleCombiner(d_model[0],
-                                           d_model[1])
+        # each one will be ZipformerEncoder or DownsampledZipformerEncoder
+        encoders = []
+
+        num_encoders = len(encoder_dims)
+        for i in range(num_encoders):
+            encoder_layer = ZipformerEncoderLayer(
+                encoder_dims[i],
+                attention_dim[i],
+                nhead[i],
+                feedforward_dim[i],
+                dropout,
+                cnn_module_kernel[i],
+            )
+
+            # For the first segment of the warmup period, we let the Conv2dSubsampling
+            # layer learn something.  Then we start to warm up the other encoders.
+            encoder = ZipformerEncoder(
+                encoder_layer,
+                num_encoder_layers[i],
+                dropout,
+                warmup_begin=warmup_batches * (i + 1) / (num_encoders + 1),
+                warmup_end=warmup_batches * (i + 2) / (num_encoders + 1),
+            )
+
+            if zipformer_downsampling_factors[i] != 1:
+                assert i > 0, "First zipformer layer cannot use downsampling"
+                encoder = DownsampledZipformerEncoder(
+                    encoder,
+                    input_dim=encoder_dims[i - 1],
+                    output_dim=encoder_dims[i],
+                    downsample=zipformer_downsampling_factors[i],
+                )
+            encoders.append(encoder)
+        self.encoders = nn.ModuleList(encoders)
-    def get_feature_mask(
+    def get_feature_masks(
             self,
-            x: torch.Tensor) -> Tuple[Union[float, Tensor], Union[float, Tensor]]:
+            x: torch.Tensor) -> List[Union[float, Tensor]]:
         """
-        In eval mode, returns 1.0; in training mode, returns two randomized feature masks
-        for the 1st and second encoders (which may run at different frame rates).
+        In eval mode, returns [1.0] * num_encoders; in training mode, returns a number of
+        randomized feature masks, one per encoder.
         On e.g. 15% of frames, these masks will zero out all encoder dims larger than
         some supplied number, e.g. >256, so in effect on those frames we are using
         a smaller encoder dim.
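A note on the warmup schedule introduced above: warmup_batches is split into num_encoders + 1 equal segments, with segment 0 reserved for the Conv2dSubsampling frontend and encoder i warming up over segment i + 1. For example, with warmup_batches = 4000 and three encoder stacks:

    # num_encoders = 3 -> four segments of 1000 batches each
    # encoder 0: warmup_begin = 4000 * 1 / 4 = 1000, warmup_end = 4000 * 2 / 4 = 2000
    # encoder 1: warmup_begin = 2000, warmup_end = 3000
    # encoder 2: warmup_begin = 3000, warmup_end = 4000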
@@ -156,40 +152,46 @@ class Zipformer(EncoderInterface):
         Args:
           x: the embeddings (needed for the shape and dtype and device), of shape
-             (num_frames, batch_size, d_model0)
+             (num_frames, batch_size, encoder_dims0)
         """
+        num_encoders = len(self.encoder_dims)
         if not self.training:
-            return 1.0, 1.0
+            return [1.0] * num_encoders
+
+        (num_frames0, batch_size, _encoder_dims0) = x.shape
+        assert self.encoder_dims[0] == _encoder_dims0
+        max_downsampling_factor = max(self.zipformer_downsampling_factors)
+        num_frames_max = (num_frames0 + max_downsampling_factor - 1)

-        d_model0, d_model1 = self.d_model
-        (num_frames0, batch_size, _d_model0) = x.shape
-        assert d_model0 == _d_model0
-        ds = self.zipformer_subsampling_factor
-        num_frames1 = ((num_frames0 + ds - 1) // ds)

         # on this proportion of the frames, drop out the extra features above
         # self.encoder_unmasked_dim.
         feature_mask_dropout_prob = 0.15

-        # frame_mask1 shape: (num_frames1, batch_size, 1)
-        frame_mask1 = (torch.rand(num_frames1, batch_size, 1, device=x.device) >
-                       feature_mask_dropout_prob).to(x.dtype)
-
-        feature_mask1 = torch.ones(num_frames1, batch_size, self.d_model[1],
-                                   dtype=x.dtype, device=x.device)
-        feature_mask1[:, :, self.encoder_unmasked_dim:] *= frame_mask1
-
-        # frame_mask0 shape: (num_frames0, batch_size, 1)
-        frame_mask0 = frame_mask1.unsqueeze(1).expand(num_frames1, ds, batch_size, 1).reshape(
-            num_frames1 * ds, batch_size, 1)[:num_frames0]
-        feature_mask0 = torch.ones(num_frames0, batch_size, self.d_model[0],
-                                   dtype=x.dtype, device=x.device)
-        feature_mask0[:, :, self.encoder_unmasked_dim:] *= frame_mask0
-
-        return feature_mask0, feature_mask1
+        # frame_mask_max shape: (num_frames_max, batch_size, 1)
+        frame_mask_max = (torch.rand(num_frames_max, batch_size, 1,
+                                     device=x.device) >
+                          feature_mask_dropout_prob).to(x.dtype)
+
+        feature_masks = []
+        for i in range(num_encoders):
+            ds = self.zipformer_downsampling_factors[i]
+            upsample_factor = (max_downsampling_factor // ds)
+
+            frame_mask = (frame_mask_max.unsqueeze(1).expand(num_frames_max, upsample_factor,
+                                                             batch_size, 1)
+                          .reshape(num_frames_max * upsample_factor, batch_size, 1))
+            num_frames = (num_frames0 + ds - 1) // ds
+            frame_mask = frame_mask[:num_frames]
+            feature_mask = torch.ones(num_frames, batch_size, self.encoder_dims[i],
+                                      dtype=x.dtype, device=x.device)
+            u = self.encoder_unmasked_dims[i]
+            feature_mask[:, :, u:] *= frame_mask
+            feature_masks.append(feature_mask)
+
+        return feature_masks

     def forward(
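To make the masking loop concrete, a hedged shape walk-through, assuming num_frames0 = 100 and zipformer_downsampling_factors = (1, 2):

    # max_downsampling_factor = 2, so num_frames_max = 100 + 2 - 1 = 101
    # encoder 0 (ds=1): upsample_factor = 2 -> 202 mask frames, truncated to ceil(100/1) = 100
    # encoder 1 (ds=2): upsample_factor = 1 -> 101 mask frames, truncated to ceil(100/2) = 50
    # each feature_mask is all-ones in dims [0, u) and zeroed on ~15% of frames in
    # dims [u, encoder_dims[i]), where u = encoder_unmasked_dims[i]

Because every stack's mask is sliced from the same frame_mask_max, the same randomly chosen coarse-rate frames are masked at every frame rate.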
@@ -204,7 +206,7 @@ class Zipformer(EncoderInterface):
           `x` before padding.
         Returns:
           Return a tuple containing 2 tensors:
-            - embeddings: its shape is (batch_size, output_seq_len, d_model)
+            - embeddings: its shape is (batch_size, output_seq_len, encoder_dims[-1])
             - lengths, a tensor of shape (batch_size,) containing the number
               of frames in `embeddings` before padding.
         """
@@ -219,18 +221,14 @@ class Zipformer(EncoderInterface):
         assert x.size(0) == lengths.max().item()
         mask = make_pad_mask(lengths)

-        feature_mask0, feature_mask1 = self.get_feature_mask(x)
+        feature_masks = self.get_feature_masks(x)

-        x1 = self.encoder1(
-            x, feature_mask=feature_mask0, src_key_padding_mask=mask,
-        )  # (T, N, C) where C == d_model[0]
-
-        x2 = self.encoder2(
-            x1, feature_mask=feature_mask1, src_key_padding_mask=mask,
-        )  # (T, N, C) where C == d_model[1]
-
-        x = self.out_combiner(x1, x2)
+        for i, module in enumerate(self.encoders):
+            ds = self.zipformer_downsampling_factors[i]
+            x = module(x,
+                       feature_mask=feature_masks[i],
+                       src_key_padding_mask=None if mask is None else mask[..., ::ds])

         x = x.permute(1, 0, 2)  # (T, N, C) -> (N, T, C)
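Note how the padding mask is matched to each stack's frame rate: src_key_padding_mask has shape (batch_size, num_frames), and mask[..., ::ds] keeps every ds-th time step, giving ceil(num_frames / ds) columns, which is exactly the downsampled length each stack runs at. A minimal check of that identity:

    import torch
    mask = torch.zeros(2, 7, dtype=torch.bool)  # (batch_size=2, num_frames=7)
    print(mask[..., ::2].shape)                 # torch.Size([2, 4]); ceil(7 / 2) == 4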
@@ -556,8 +554,8 @@ class ZipformerEncoder(nn.Module):
 class DownsampledZipformerEncoder(nn.Module):
     r"""
     DownsampledZipformerEncoder is a zipformer encoder evaluated at a reduced frame rate,
-    after convolutional downsampling, and then upsampled again at the output
-    so that the output has the same shape as the input.
+    after convolutional downsampling, and then upsampled again at the output, and combined
+    with the original input, so that the output has the same shape as the input.
     """
     def __init__(self,
                  encoder: nn.Module,
@@ -569,6 +567,67 @@ class DownsampledZipformerEncoder(nn.Module):
         self.downsample = AttentionDownsample(input_dim, output_dim, downsample)
         self.encoder = encoder
         self.upsample = SimpleUpsample(output_dim, downsample)
+        self.out_combiner = SimpleCombiner(input_dim,
+                                           output_dim)
+
+    def forward(self,
+                src: Tensor,
+                feature_mask: Union[Tensor, float] = 1.0,
+                mask: Optional[Tensor] = None,
+                src_key_padding_mask: Optional[Tensor] = None,
+                ) -> Tuple[Tensor, Tensor]:
+        r"""Downsample, go through encoder, upsample.
+
+        Args:
+            src: the sequence to the encoder (required).
+            feature_mask: something that broadcasts with src, that we'll multiply `src`
+                by at every layer.  feature_mask is expected to be already downsampled by
+                self.downsample_factor.
+            mask: the mask for the src sequence (optional).  CAUTION: we need to downsample
+                this, if we are to support it.  Won't work correctly yet.
+            src_key_padding_mask: the mask for the src keys per batch (optional).  Should
+                be downsampled already.
+
+        Shape:
+            src: (S, N, E).
+            mask: (S, S).
+            src_key_padding_mask: (N, S).
+            S is the source sequence length, N is the batch size, E is the feature number.
+
+        Returns: output of shape (S, N, F) where F is the number of output features
+            (output_dim to constructor).
+        """
+        src_orig = src
+        src = self.downsample(src)
+        ds = self.downsample_factor
+        if mask is not None:
+            mask = mask[::ds, ::ds]
+
+        src = self.encoder(
+            src, feature_mask=feature_mask, mask=mask, src_key_padding_mask=mask,
+        )
+        src = self.upsample(src)
+        # remove any extra frames that are not a multiple of downsample_factor
+        src = src[:src_orig.shape[0]]
+
+        return self.out_combiner(src_orig, src)
+
+
+class DownsamplingZipformerEncoder(nn.Module):
+    r"""
+    DownsamplingZipformerEncoder is a zipformer encoder that downsamples its input
+    by a specified factor before feeding it to the zipformer layers.
+    """
+    def __init__(self,
+                 encoder: nn.Module,
+                 input_dim: int,
+                 output_dim: int,
+                 downsample: int):
+        super(DownsamplingZipformerEncoder, self).__init__()
+        self.downsample_factor = downsample
+        self.downsample = AttentionDownsample(input_dim, output_dim, downsample)
+        self.encoder = encoder

     def forward(self,
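Taken together, the new DownsampledZipformerEncoder.forward is a residual-style wrapper around the inner encoder. A hedged shape walk-through with downsample=2 and an odd-length input, assuming AttentionDownsample rounds the frame count up and SimpleCombiner projects the combined pair to output_dim:

    # src:                  (S=101, N, input_dim)
    # self.downsample(src): (51, N, output_dim)   # ceil(101 / 2) frames
    # self.encoder(...):    (51, N, output_dim)
    # self.upsample(src):   (102, N, output_dim)  # each frame repeated 2x
    # src[:101]:            trims the extra frame introduced by rounding up
    # self.out_combiner(src_orig, src) -> (101, N, output_dim)

DownsamplingZipformerEncoder, by contrast, keeps only the downsampling half; with the upsampling tail removed in the next hunk, its inherited forward returns output at the reduced frame rate.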
@@ -608,10 +667,6 @@ class DownsampledZipformerEncoder(nn.Module):
         src = self.encoder(
             src, feature_mask=feature_mask, mask=mask, src_key_padding_mask=mask,
         )
-        src = self.upsample(src)
-        # remove any extra frames that are not a multiple of downsample_factor
-        src = src[:src_orig.shape[0]]

         return src
@@ -1642,7 +1697,7 @@ def _test_zipformer_main():
     # Just make sure the forward pass runs.
     c = Zipformer(
-        num_features=feature_dim, d_model=(64,96), encoder_unmasked_dim=64, nhead=(4,4)
+        num_features=feature_dim, encoder_dims=(64,96), encoder_unmasked_dims=(48,64), nhead=(4,4)
     )
     batch_size = 5
     seq_len = 20
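The rest of the test is not shown in this hunk; a hedged sketch of how the forward pass might be exercised, with feature_dim assumed to be defined earlier in the test and the call signature taken from Zipformer.forward above (a (N, T, num_features) tensor plus per-utterance lengths, returning embeddings and output lengths):

    import torch

    feats = torch.randn(batch_size, seq_len, feature_dim)
    lengths = torch.full((batch_size,), seq_len, dtype=torch.int64)
    embeddings, out_lengths = c(feats, lengths)
    # embeddings: (batch_size, output_seq_len, encoder_dims[-1]) == (5, T', 96)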