Refactor how the downsampling is done so that it happens later, but the 1st encoder stack still operates after a subsampling of 2.
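
In rough terms: the conv front-end used to subsample the time axis by 4 before any encoder stack ran; after this change it subsamples only by 2, and the remaining factor of 2 is applied later, at the encoder output. An editor's sketch of the length arithmetic, using only the formulas visible in this diff (and assuming the final halving applies to the conv-rate lengths):

    # Old front-end: ~T/4.  New front-end: ~T/2, then downsample_output
    # halves it again (rounding up), so the final rate is still ~T/4.
    for T in (100, 101, 3000):
        old_conv = ((T - 1) // 2 - 1) // 2  # old Conv2dSubsampling output
        new_conv = (T - 7) // 2             # new Conv2dSubsampling output
        new_final = (new_conv + 1) // 2     # after downsample_output
        print(T, old_conv, new_conv, new_final)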

Daniel Povey 2022-10-28 19:20:21 +08:00
parent 0a89f51dc9
commit d7d5188bd9
2 changed files with 24 additions and 22 deletions

train.py

@@ -138,7 +138,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
     parser.add_argument(
         "--zipformer-downsampling-factors",
         type=str,
-        default="1,2,4",
+        default="2,4,8",
         help="Downsampling factor for each stack of encoder layers.",
     )
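
The doubled defaults track the front-end change: the conv subsampling drops from 4x to 2x, so each stack's factor doubles to keep its effective frame rate. A sketch of that arithmetic, assuming (as the mask[..., ::ds] indexing in zipformer.py suggests) that each per-stack factor applies to the conv output:

    conv_old, factors_old = 4, (1, 2, 4)
    conv_new, factors_new = 2, (2, 4, 8)
    assert [conv_old * f for f in factors_old] == \
           [conv_new * f for f in factors_new] == [4, 8, 16]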
@@ -428,7 +428,7 @@ def get_params() -> AttributeDict:
             "valid_interval": 3000,  # For the 100h subset, use 800
             # parameters for zipformer
             "feature_dim": 80,
-            "subsampling_factor": 4,
+            "subsampling_factor": 4,  # not passed in, this is fixed.
             "warm_step": 2000,
             "env_info": get_env_info(),
         }
@@ -443,7 +443,7 @@ def get_encoder_model(params: AttributeDict) -> nn.Module:
         return tuple(map(int, s.split(',')))
     encoder = Zipformer(
         num_features=params.feature_dim,
-        subsampling_factor=params.subsampling_factor,
+        output_downsampling_factor=2,
         zipformer_downsampling_factors=to_int_tuple(params.zipformer_downsampling_factors),
         encoder_dims=to_int_tuple(params.encoder_dims),
         attention_dim=to_int_tuple(params.attention_dims),
zipformer.py

@@ -47,7 +47,6 @@ class Zipformer(EncoderInterface):
     """
     Args:
         num_features (int): Number of input features
-        subsampling_factor (int): subsampling factor of encoder (the convolution layers before transformers)
         d_model: (int,int): embedding dimension of 2 encoder stacks
         attention_dim: (int,int): attention dimension of 2 encoder stacks
         nhead (int, int): number of heads
@@ -62,12 +61,11 @@ class Zipformer(EncoderInterface):
     def __init__(
         self,
         num_features: int,
-        subsampling_factor: int = 4,
-        zipformer_subsampling_factor: int = 4,
+        output_downsampling_factor: int = 2,
         encoder_dims: Tuple[int] = (384, 384),
         attention_dim: Tuple[int] = (256, 256),
         encoder_unmasked_dims: Tuple[int] = (256, 256),
-        zipformer_downsampling_factors: Tuple[int] = (1, 2),
+        zipformer_downsampling_factors: Tuple[int] = (2, 4),
         nhead: Tuple[int] = (8, 8),
         feedforward_dim: Tuple[int] = (1536, 2048),
         num_encoder_layers: Tuple[int] = (12, 12),
@@ -78,23 +76,20 @@ class Zipformer(EncoderInterface):
         super(Zipformer, self).__init__()
         self.num_features = num_features
-        self.subsampling_factor = subsampling_factor
-        self.encoder_unmasked_dims = encoder_unmasked_dims
         assert 0 < encoder_dims[0] <= encoder_dims[1]
         self.encoder_dims = encoder_dims
+        self.encoder_unmasked_dims = encoder_unmasked_dims
         self.zipformer_downsampling_factors = zipformer_downsampling_factors
+        self.output_downsampling_factor = output_downsampling_factor
         for u,d in zip(encoder_unmasked_dims, encoder_dims):
             assert u <= d
-        if subsampling_factor != 4:
-            raise NotImplementedError("Support only 'subsampling_factor=4'.")
         # self.encoder_embed converts the input of shape (N, T, num_features)
         # to the shape (N, T//subsampling_factor, encoder_dims).
         # That is, it does two things simultaneously:
-        #   (1) subsampling: T -> T//subsampling_factor
+        #   (1) subsampling: T -> T//2
         #   (2) embedding: num_features -> encoder_dims
         self.encoder_embed = Conv2dSubsampling(num_features, encoder_dims[0],
                                                dropout=dropout)
@@ -125,10 +120,9 @@ class Zipformer(EncoderInterface):
             )
             if zipformer_downsampling_factors[i] != 1:
-                assert i > 0, "First zipformer layer cannot use downsampling"
                 encoder = DownsampledZipformerEncoder(
                     encoder,
-                    input_dim=encoder_dims[i-1],
+                    input_dim=encoder_dims[i-1] if i > 0 else encoder_dims[0],
                     output_dim=encoder_dims[i],
                     downsample=zipformer_downsampling_factors[i],
                 )
@@ -136,6 +130,10 @@ class Zipformer(EncoderInterface):
         self.encoders = nn.ModuleList(encoders)
+
+        self.downsample_output = AttentionDownsample(encoder_dims[-1],
+                                                     encoder_dims[-1],
+                                                     downsample=output_downsampling_factor)
 
     def get_feature_masks(
             self,
             x: torch.Tensor) -> List[Union[float, Tensor]]:
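
The new AttentionDownsample on the last stack's output restores the overall 4x reduction that the old conv front-end produced by itself, which is presumably why get_params() above keeps "subsampling_factor": 4 as a fixed, not-passed-in value. As a one-line check (editor's sketch):

    conv_factor = 2                  # new Conv2dSubsampling
    output_downsampling_factor = 2   # applied by self.downsample_output
    assert conv_factor * output_downsampling_factor == 4  # matches get_params()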
@@ -216,8 +214,7 @@ class Zipformer(EncoderInterface):
         with warnings.catch_warnings():
             warnings.simplefilter("ignore")
-            # Caution: We assume the subsampling factor is 4!
-            lengths = ((x_lens - 1) // 2 - 1) // 2
+            lengths = (x_lens - 7) // 2
         assert x.size(0) == lengths.max().item()
         mask = make_pad_mask(lengths)
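
Where (x_lens - 7) // 2 comes from: with the kernel/stride/padding values changed further down in this diff, each of the three Conv2d layers (kernel size 3) shortens the time axis in turn. A worked check in plain integer arithmetic (editor's sketch, not commit code):

    T = 100
    t1 = T - 2               # layer 1: stride 1, no time padding
    t2 = (t1 - 3) // 2 + 1   # layer 2: stride 2, no padding
    t3 = t2 - 2              # layer 3: time stride is now 1
    assert t2 == (T - 3) // 2
    assert t3 == (T - 7) // 2 == 46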
@@ -229,6 +226,10 @@ class Zipformer(EncoderInterface):
                        feature_mask=feature_masks[i],
                        src_key_padding_mask=None if mask is None else mask[...,::ds])
 
+        x = self.downsample_output(x)
+        # class Downsample has this rounding behavior..
+        lengths = (x_lens + 1) // 2
+
         x = x.permute(1, 0, 2)  # (T, N, C) ->(N, T, C)
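
The + 1 makes the halving round up (ceiling division), so an odd-length sequence keeps its trailing frame, which is the rounding the comment attributes to the downsampling class. For example:

    assert (10 + 1) // 2 == 5  # even length: plain halving
    assert (11 + 1) // 2 == 6  # odd length: rounds up

(Whether the added line should halve x_lens or the conv-rate lengths computed above cannot be told from this hunk alone.)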
@@ -1468,7 +1469,7 @@ class Conv2dSubsampling(nn.Module):
     Convert an input of shape (N, T, idim) to an output
     with shape (N, T', odim), where
-    T' = ((T-1)//2 - 1)//2, which approximates T' == T//4
+    T' = (T-3)//2 - 2 == (T-7)//2
 
     It is based on
     https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/subsampling.py  # noqa
@@ -1489,7 +1490,7 @@ class Conv2dSubsampling(nn.Module):
           Number of channels in. The input shape is (N, T, in_channels).
           Caution: It requires: T >=7, in_channels >=7
         out_channels
-          Output dim. The output shape is (N, ((T-1)//2 - 1)//2, out_channels)
+          Output dim. The output shape is (N, (T-3)//2, out_channels)
         layer1_channels:
           Number of channels in layer1
         layer1_channels:
@@ -1503,7 +1504,7 @@ class Conv2dSubsampling(nn.Module):
                 in_channels=1,
                 out_channels=layer1_channels,
                 kernel_size=3,
-                padding=1,
+                padding=(0, 1),  # (time, freq)
             ),
             ActivationBalancer(layer1_channels,
                                channel_dim=1),
@@ -1513,6 +1514,7 @@ class Conv2dSubsampling(nn.Module):
                 out_channels=layer2_channels,
                 kernel_size=3,
                 stride=2,
+                padding=0,
             ),
             ActivationBalancer(layer2_channels,
                                channel_dim=1),
@@ -1521,13 +1523,13 @@ class Conv2dSubsampling(nn.Module):
                 in_channels=layer2_channels,
                 out_channels=layer3_channels,
                 kernel_size=3,
-                stride=2,
+                stride=(1, 2),  # (time, freq)
             ),
             ActivationBalancer(layer3_channels,
                                channel_dim=1),
             DoubleSwish(),
         )
-        out_height = (((in_channels - 1) // 2 - 1) // 2)
+        out_height = (((in_channels - 1) // 2) - 1) // 2
         self.out = ScaledLinear(out_height * layer3_channels, out_channels)
         self.dropout = nn.Dropout(dropout)
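
The net geometry of the rewritten front-end can be checked with bare nn.Conv2d layers using the kernel/stride/padding values above; ActivationBalancer and DoubleSwish are shape-preserving, so this editor's sketch omits them:

    import torch
    import torch.nn as nn

    conv = nn.Sequential(
        nn.Conv2d(1, 8, kernel_size=3, padding=(0, 1)),       # time: T -> T - 2
        nn.Conv2d(8, 8, kernel_size=3, stride=2, padding=0),  # time: -> (T - 3) // 2
        nn.Conv2d(8, 8, kernel_size=3, stride=(1, 2)),        # time: -> (T - 7) // 2
    )
    for T in (50, 101, 256):
        x = torch.zeros(1, 1, T, 80)  # (N, C, time, freq), idim = 80
        _, _, t, f = conv(x).shape
        assert t == (T - 7) // 2
        assert f == (((80 - 1) // 2) - 1) // 2 == 19  # the out_height formula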
@@ -1545,7 +1547,7 @@ class Conv2dSubsampling(nn.Module):
         # On entry, x is (N, T, idim)
         x = x.unsqueeze(1)  # (N, T, idim) -> (N, 1, T, idim) i.e., (N, C, H, W)
         x = self.conv(x)
-        # Now x is of shape (N, odim, ((T-1)//2 - 1)//2, ((idim-1)//2 - 1)//2)
+        # Now x is of shape (N, odim, ((T-3)//2 - 1)//2, ((idim-1)//2 - 1)//2)
         b, c, t, f = x.size()
         x = self.out(x.transpose(1, 2).reshape(b, t, c * f))
         # Now x is of shape (N, ((T-1)//2 - 1))//2, odim)