Refactor zipformer for more flexibility so we can change the number of encoder layers.

parent e592a920b4
commit ed1b4d5e5d
@@ -127,19 +127,19 @@ def add_model_arguments(parser: argparse.ArgumentParser):
     )
 
     parser.add_argument(
-        "--encoder-unmasked-dim",
-        type=int,
-        default=256,
-        help="Unmasked dimension in the encoder, relates to augmentation during training. "
+        "--encoder-unmasked-dims",
+        type=str,
+        default="256,256",
+        help="Unmasked dimensions in the encoders, relates to augmentation during training. "
         "Must be <= each of encoder_dims. Empirically, less than 256 seems to make performance "
         " worse."
     )
 
     parser.add_argument(
-        "--zipformer-subsampling-factor",
-        type=int,
-        default=2,
-        help="Subsampling factor for 2nd stack of encoder layers.",
+        "--zipformer-downsampling-factors",
+        type=str,
+        default="1,2",
+        help="Downsampling factor for each stack of encoder layers.",
     )
 
     parser.add_argument(
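The plural options arrive as comma-separated strings (e.g. `--encoder-unmasked-dims 256,256`), one value per encoder stack, so the number of stacks can be changed from the command line alone. A minimal sketch of the conversion helper used at the call sites below (the script's actual `to_int_tuple` may differ in detail):

    def to_int_tuple(s: str) -> tuple:
        # "256,256" -> (256, 256)
        return tuple(int(x) for x in s.split(","))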
@@ -437,10 +437,10 @@ def get_encoder_model(params: AttributeDict) -> nn.Module:
     encoder = Zipformer(
         num_features=params.feature_dim,
         subsampling_factor=params.subsampling_factor,
-        zipformer_subsampling_factor=params.zipformer_subsampling_factor,
-        d_model=to_int_tuple(params.encoder_dims),
+        zipformer_downsampling_factors=to_int_tuple(params.zipformer_downsampling_factors),
+        encoder_dims=to_int_tuple(params.encoder_dims),
         attention_dim=to_int_tuple(params.attention_dims),
-        encoder_unmasked_dim=params.encoder_unmasked_dim,
+        encoder_unmasked_dims=to_int_tuple(params.encoder_unmasked_dims),
         nhead=to_int_tuple(params.nhead),
         feedforward_dim=to_int_tuple(params.feedforward_dims),
         num_encoder_layers=to_int_tuple(params.num_encoder_layers),
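Every tuple-valued argument here must have one entry per encoder stack. A hypothetical sanity check (not part of this patch) that could run before constructing the model:

    def check_per_stack_params(params) -> None:
        # All per-stack options must list the same number of values.
        n = len(to_int_tuple(params.encoder_dims))
        for name in ("attention_dims", "encoder_unmasked_dims",
                     "zipformer_downsampling_factors", "nhead",
                     "feedforward_dims", "num_encoder_layers"):
            values = to_int_tuple(getattr(params, name))
            assert len(values) == n, f"{name}: got {len(values)} values, expected {n}"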
@@ -64,9 +64,10 @@ class Zipformer(EncoderInterface):
             num_features: int,
             subsampling_factor: int = 4,
-            zipformer_subsampling_factor: int = 4,
-            d_model: Tuple[int] = (384, 384),
+            encoder_dims: Tuple[int] = (384, 384),
             attention_dim: Tuple[int] = (256, 256),
-            encoder_unmasked_dim: int = 256,
+            encoder_unmasked_dims: Tuple[int] = (256, 256),
+            zipformer_downsampling_factors: Tuple[int] = (1, 2),
             nhead: Tuple[int] = (8, 8),
             feedforward_dim: Tuple[int] = (1536, 2048),
             num_encoder_layers: Tuple[int] = (12, 12),
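The tuple lengths themselves now determine the number of stacks. For illustration, a three-stack model could be configured as below; the values are made up, and `cnn_module_kernel` (whose default is not visible in this diff) is given explicitly so that every per-stack tuple has length 3:

    model = Zipformer(
        num_features=80,
        encoder_dims=(256, 384, 384),
        attention_dim=(192, 192, 192),
        encoder_unmasked_dims=(192, 256, 256),
        zipformer_downsampling_factors=(1, 2, 4),  # first stack must not downsample
        nhead=(8, 8, 8),
        feedforward_dim=(1024, 1536, 1536),
        num_encoder_layers=(8, 12, 12),
        cnn_module_kernel=(31, 31, 31),
    )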
@@ -78,74 +79,69 @@ class Zipformer(EncoderInterface):
         self.num_features = num_features
         self.subsampling_factor = subsampling_factor
-        self.encoder_unmasked_dim = encoder_unmasked_dim
-        assert 0 < d_model[0] <= d_model[1]
-        self.d_model = d_model
-        self.zipformer_subsampling_factor = zipformer_subsampling_factor
+        assert 0 < encoder_dims[0] <= encoder_dims[1]
+        self.encoder_dims = encoder_dims
+        self.encoder_unmasked_dims = encoder_unmasked_dims
+        self.zipformer_downsampling_factors = zipformer_downsampling_factors
 
-        assert encoder_unmasked_dim <= d_model[0] and encoder_unmasked_dim <= d_model[1]
+        for u, d in zip(encoder_unmasked_dims, encoder_dims):
+            assert u <= d
 
         if subsampling_factor != 4:
             raise NotImplementedError("Support only 'subsampling_factor=4'.")
 
         # self.encoder_embed converts the input of shape (N, T, num_features)
-        # to the shape (N, T//subsampling_factor, d_model).
+        # to the shape (N, T//subsampling_factor, encoder_dims).
         # That is, it does two things simultaneously:
         #   (1) subsampling: T -> T//subsampling_factor
-        #   (2) embedding: num_features -> d_model
-        self.encoder_embed = Conv2dSubsampling(num_features, d_model[0],
+        #   (2) embedding: num_features -> encoder_dims
+        self.encoder_embed = Conv2dSubsampling(num_features, encoder_dims[0],
                                                dropout=dropout)
 
-        encoder_layer1 = ZipformerEncoderLayer(
-            d_model[0],
-            attention_dim[0],
-            nhead[0],
-            feedforward_dim[0],
-            dropout,
-            cnn_module_kernel[0],
-        )
-        # For the first third of the warmup period, we let the Conv2dSubsampling
-        # layer learn something.  Then start warming up the first and then the
-        # second encoder.
-        self.encoder1 = ZipformerEncoder(
-            encoder_layer1,
-            num_encoder_layers[0],
-            dropout,
-            warmup_begin=warmup_batches / 3,
-            warmup_end=warmup_batches * 2 / 3,
-        )
-        encoder_layer2 = ZipformerEncoderLayer(
-            d_model[1],
-            attention_dim[1],
-            nhead[1],
-            feedforward_dim[1],
-            dropout,
-            cnn_module_kernel[1],
-        )
-        self.encoder2 = DownsampledZipformerEncoder(
-            ZipformerEncoder(
-                encoder_layer2,
-                num_encoder_layers[1],
-                dropout,
-                warmup_begin=warmup_batches * 2 / 3,
-                warmup_end=warmup_batches,
-            ),
-            input_dim=d_model[0],
-            output_dim=d_model[1],
-            downsample=zipformer_subsampling_factor,
-        )
-
-        self.out_combiner = SimpleCombiner(d_model[0],
-                                           d_model[1])
+        # each one will be ZipformerEncoder or DownsampledZipformerEncoder
+        encoders = []
+
+        num_encoders = len(encoder_dims)
+        for i in range(num_encoders):
+            encoder_layer = ZipformerEncoderLayer(
+                encoder_dims[i],
+                attention_dim[i],
+                nhead[i],
+                feedforward_dim[i],
+                dropout,
+                cnn_module_kernel[i],
+            )
+
+            # For the first segment of the warmup period, we let the Conv2dSubsampling
+            # layer learn something.  Then we start to warm up the other encoders.
+            encoder = ZipformerEncoder(
+                encoder_layer,
+                num_encoder_layers[i],
+                dropout,
+                warmup_begin=warmup_batches * (i + 1) / (num_encoders + 1),
+                warmup_end=warmup_batches * (i + 2) / (num_encoders + 1)
+            )
+
+            if zipformer_downsampling_factors[i] != 1:
+                assert i > 0, "First zipformer layer cannot use downsampling"
+                encoder = DownsampledZipformerEncoder(
+                    encoder,
+                    input_dim=encoder_dims[i - 1],
+                    output_dim=encoder_dims[i],
+                    downsample=zipformer_downsampling_factors[i],
+                )
+            encoders.append(encoder)
+        self.encoders = nn.ModuleList(encoders)
 
-    def get_feature_mask(
+    def get_feature_masks(
             self,
-            x: torch.Tensor) -> Tuple[Union[float, Tensor], Union[float, Tensor]]:
+            x: torch.Tensor) -> List[Union[float, Tensor]]:
         """
-        In eval mode, returns 1.0; in training mode, returns two randomized feature masks
-        for the 1st and second encoders (which may run at different frame rates).
+        In eval mode, returns [1.0] * num_encoders; in training mode, returns a number of
+        randomized feature masks, one per encoder.
         On e.g. 15% of frames, these masks will zero out all encoder dims larger than
         some supplied number, e.g. >256, so in effect on those frames we are using
         a smaller encoder dim.
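The old hard-coded thirds of the warmup period generalize to an even split into `num_encoders + 1` segments: the first segment is left to the Conv2dSubsampling frontend, then each stack warms up in turn. Plain arithmetic with a hypothetical `warmup_batches = 4000.0` and three stacks:

    warmup_batches = 4000.0
    num_encoders = 3
    for i in range(num_encoders):
        begin = warmup_batches * (i + 1) / (num_encoders + 1)
        end = warmup_batches * (i + 2) / (num_encoders + 1)
        print(i, begin, end)
    # 0 1000.0 2000.0
    # 1 2000.0 3000.0
    # 2 3000.0 4000.0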
@@ -156,40 +152,46 @@ class Zipformer(EncoderInterface):
 
         Args:
           x: the embeddings (needed for the shape and dtype and device), of shape
-             (num_frames, batch_size, d_model0)
+             (num_frames, batch_size, encoder_dims0)
         """
+        num_encoders = len(self.encoder_dims)
         if not self.training:
-            return 1.0, 1.0
+            return [ 1.0 ] * num_encoders
 
-        d_model0, d_model1 = self.d_model
-        (num_frames0, batch_size, _d_model0) = x.shape
-        assert d_model0 == _d_model0
-        ds = self.zipformer_subsampling_factor
-        num_frames1 = ((num_frames0 + ds - 1) // ds)
+        (num_frames0, batch_size, _encoder_dims0) = x.shape
+
+        assert self.encoder_dims[0] == _encoder_dims0
+
+        max_downsampling_factor = max(self.zipformer_downsampling_factors)
+
+        num_frames_max = (num_frames0 + max_downsampling_factor - 1)
 
         # on this proportion of the frames, drop out the extra features above
         # self.encoder_unmasked_dim.
         feature_mask_dropout_prob = 0.15
 
-        # frame_mask1 shape: (num_frames1, batch_size, 1)
-        frame_mask1 = (torch.rand(num_frames1, batch_size, 1, device=x.device) >
-                       feature_mask_dropout_prob).to(x.dtype)
-
-        feature_mask1 = torch.ones(num_frames1, batch_size, self.d_model[1],
-                                   dtype=x.dtype, device=x.device)
-        feature_mask1[:, :, self.encoder_unmasked_dim:] *= frame_mask1
-
-        # frame_mask0 shape: (num_frames0, batch_size, 1)
-        frame_mask0 = frame_mask1.unsqueeze(1).expand(num_frames1, ds, batch_size, 1).reshape(
-            num_frames1 * ds, batch_size, 1)[:num_frames0]
-
-        feature_mask0 = torch.ones(num_frames0, batch_size, self.d_model[0],
-                                   dtype=x.dtype, device=x.device)
-        feature_mask0[:, :, self.encoder_unmasked_dim:] *= frame_mask0
-
-        return feature_mask0, feature_mask1
+        # frame_mask_max shape: (num_frames_max, batch_size, 1)
+        frame_mask_max = (torch.rand(num_frames_max, batch_size, 1,
+                                     device=x.device) >
+                          feature_mask_dropout_prob).to(x.dtype)
+
+        feature_masks = []
+        for i in range(num_encoders):
+            ds = self.zipformer_downsampling_factors[i]
+            upsample_factor = (max_downsampling_factor // ds)
+
+            frame_mask = (frame_mask_max.unsqueeze(1).expand(num_frames_max, upsample_factor,
+                                                             batch_size, 1)
+                          .reshape(num_frames_max * upsample_factor, batch_size, 1))
+            num_frames = (num_frames0 + ds - 1) // ds
+            frame_mask = frame_mask[:num_frames]
+            feature_mask = torch.ones(num_frames, batch_size, self.encoder_dims[i],
+                                      dtype=x.dtype, device=x.device)
+            u = self.encoder_unmasked_dims[i]
+            feature_mask[:, :, u:] *= frame_mask
+            feature_masks.append(feature_mask)
+
+        return feature_masks
 
 
     def forward(
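The masking idea is unchanged from the two-encoder version: on roughly 15% of frames, zero out every channel above the unmasked dimension, so the model also learns to operate with a narrower encoder. A simplified standalone sketch of one such mask (ignoring the shared `frame_mask_max` upsampling across frame rates):

    import torch

    num_frames, batch_size, encoder_dim, unmasked_dim = 100, 4, 384, 256
    frame_mask = (torch.rand(num_frames, batch_size, 1) > 0.15).float()
    feature_mask = torch.ones(num_frames, batch_size, encoder_dim)
    # channels >= unmasked_dim are zeroed on ~15% of frames
    feature_mask[:, :, unmasked_dim:] *= frame_mask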
@@ -204,7 +206,7 @@ class Zipformer(EncoderInterface):
           `x` before padding.
         Returns:
           Return a tuple containing 2 tensors:
-            - embeddings: its shape is (batch_size, output_seq_len, d_model)
+            - embeddings: its shape is (batch_size, output_seq_len, encoder_dims[-1])
             - lengths, a tensor of shape (batch_size,) containing the number
               of frames in `embeddings` before padding.
         """
@@ -219,18 +221,14 @@ class Zipformer(EncoderInterface):
         assert x.size(0) == lengths.max().item()
         mask = make_pad_mask(lengths)
 
-        feature_mask0, feature_mask1 = self.get_feature_mask(x)
+        feature_masks = self.get_feature_masks(x)
 
-        # x1:
-        x1 = self.encoder1(
-            x, feature_mask=feature_mask0, src_key_padding_mask=mask,
-        )  # (T, N, C) where C == d_model[0]
-
-        x2 = self.encoder2(
-            x1, feature_mask=feature_mask1, src_key_padding_mask=mask,
-        )  # (T, N, C) where C == d_model[1]
-
-        x = self.out_combiner(x1, x2)
+        for i, module in enumerate(self.encoders):
+            ds = self.zipformer_downsampling_factors[i]
+            x = module(x,
+                       feature_mask=feature_masks[i],
+                       src_key_padding_mask=None if mask is None else mask[..., ::ds])
 
         x = x.permute(1, 0, 2)  # (T, N, C) -> (N, T, C)
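Because each stack may run at its own frame rate, the (batch, time) key-padding mask from `make_pad_mask` is strided by `ds` along the time axis before being handed to that stack. A small illustration with assumed shapes (not from the patch); note the strided length matches `(T + ds - 1) // ds`, the frame count the downsampled encoder actually sees:

    import torch

    # True marks padding; shape (N=1, T=6)
    mask = torch.tensor([[False, False, False, False, True, True]])
    ds = 2
    print(mask[..., ::ds])  # tensor([[False, False, True]]), shape (1, 3)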
@@ -556,8 +554,8 @@ class ZipformerEncoder(nn.Module):
 class DownsampledZipformerEncoder(nn.Module):
     r"""
     DownsampledZipformerEncoder is a zipformer encoder evaluated at a reduced frame rate,
-    after convolutional downsampling, and then upsampled again at the output
-    so that the output has the same shape as the input.
+    after convolutional downsampling, and then upsampled again at the output, and combined
+    with the original input, so that the output has the same shape as the input.
     """
     def __init__(self,
                  encoder: nn.Module,
@@ -569,6 +567,67 @@ class DownsampledZipformerEncoder(nn.Module):
         self.downsample = AttentionDownsample(input_dim, output_dim, downsample)
         self.encoder = encoder
+        self.upsample = SimpleUpsample(output_dim, downsample)
+        self.out_combiner = SimpleCombiner(input_dim,
+                                           output_dim)
+
+    def forward(self,
+                src: Tensor,
+                feature_mask: Union[Tensor, float] = 1.0,
+                mask: Optional[Tensor] = None,
+                src_key_padding_mask: Optional[Tensor] = None,
+                ) -> Tensor:
+        r"""Downsample, go through encoder, upsample.
+
+        Args:
+            src: the sequence to the encoder (required).
+            feature_mask: something that broadcasts with src, that we'll multiply `src`
+                by at every layer.  feature_mask is expected to be already downsampled by
+                self.downsample_factor.
+            mask: the mask for the src sequence (optional).  CAUTION: we need to downsample
+                this, if we are to support it.  Won't work correctly yet.
+            src_key_padding_mask: the mask for the src keys per batch (optional).  Should
+                be downsampled already.
+
+        Shape:
+            src: (S, N, E).
+            mask: (S, S).
+            src_key_padding_mask: (N, S).
+            S is the source sequence length, N is the batch size, E is the feature number.
+
+        Returns: output of shape (S, N, F) where F is the number of output features
+            (output_dim to constructor)
+        """
+        src_orig = src
+        src = self.downsample(src)
+        ds = self.downsample_factor
+        if mask is not None:
+            mask = mask[::ds, ::ds]
+
+        src = self.encoder(
+            src, feature_mask=feature_mask, mask=mask,
+            src_key_padding_mask=src_key_padding_mask,
+        )
+        src = self.upsample(src)
+        # remove any extra frames that are not a multiple of downsample_factor
+        src = src[:src_orig.shape[0]]
+
+        return self.out_combiner(src_orig, src)
+
+
+class DownsamplingZipformerEncoder(nn.Module):
+    r"""
+    DownsamplingZipformerEncoder is a zipformer encoder that downsamples its input
+    by a specified factor before feeding it to the zipformer layers.
+    """
+    def __init__(self,
+                 encoder: nn.Module,
+                 input_dim: int,
+                 output_dim: int,
+                 downsample: int):
+        super(DownsamplingZipformerEncoder, self).__init__()
+        self.downsample_factor = downsample
+        self.downsample = AttentionDownsample(input_dim, output_dim, downsample)
+        self.encoder = encoder
 
     def forward(self,
@@ -608,10 +667,6 @@ class DownsampledZipformerEncoder(nn.Module):
         src = self.encoder(
             src, feature_mask=feature_mask, mask=mask, src_key_padding_mask=mask,
         )
-        src = self.upsample(src)
-        # remove any extra frames that are not a multiple of downsample_factor
-        src = src[:src_orig.shape[0]]
-
         return src
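The wrapper's time arithmetic: downsampling rounds the length up to a whole number of `ds`-frame groups (assuming `AttentionDownsample` pads the final group, as the trimming comment implies), upsampling multiplies it back, and the final slice trims to the original length before the residual combination. Worked through for an assumed input of 103 frames with `ds = 2`:

    S, ds = 103, 2
    down = (S + ds - 1) // ds  # 52 frames after self.downsample
    up = down * ds             # 104 frames after self.upsample
    out = min(up, S)           # 103 after src[:src_orig.shape[0]]
    print(down, up, out)       # 52 104 103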
@@ -1642,7 +1697,7 @@ def _test_zipformer_main():
     # Just make sure the forward pass runs.
 
     c = Zipformer(
-        num_features=feature_dim, d_model=(64,96), encoder_unmasked_dim=64, nhead=(4,4)
+        num_features=feature_dim, encoder_dims=(64,96), encoder_unmasked_dims=(48,64), nhead=(4,4)
     )
     batch_size = 5
     seq_len = 20
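The forward call this test exercises could plausibly look like the following sketch; the `(x, x_lens)` calling convention and the output shape are inferred from the forward docstring above, not shown in this diff:

    x = torch.randn(batch_size, seq_len, feature_dim)
    x_lens = torch.full((batch_size,), seq_len, dtype=torch.int64)
    y, y_lens = c(x, x_lens)
    # y: (batch_size, output_seq_len, 96), i.e. encoder_dims[-1]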