Implement Nextformer-style frontend

Daniel Povey 2022-12-15 21:48:32 +08:00
parent 37a8c30136
commit 076b18db60


@@ -735,61 +735,6 @@ class DownsampledZipformerEncoder(nn.Module):
        return self.out_combiner(src_orig, src)

class DownsamplingZipformerEncoder(nn.Module):
    r"""
    DownsamplingZipformerEncoder is a zipformer encoder that downsamples its input
    by a specified factor before feeding it to the zipformer layers.
    """
    def __init__(self,
                 encoder: nn.Module,
                 input_dim: int,
                 output_dim: int,
                 downsample: int):
        super(DownsamplingZipformerEncoder, self).__init__()
        self.downsample_factor = downsample
        self.downsample = AttentionDownsample(input_dim, output_dim, downsample)
        self.encoder = encoder

    def forward(self,
                src: Tensor,
                feature_mask: Union[Tensor, float] = 1.0,
                mask: Optional[Tensor] = None,
                src_key_padding_mask: Optional[Tensor] = None,
                ) -> Tensor:
        r"""Downsample, then go through the encoder.
        Args:
            src: the sequence to the encoder (required).
            feature_mask: something that broadcasts with src, that we'll multiply `src`
                by at every layer.  feature_mask is expected to be already downsampled by
                self.downsample_factor.
            mask: the mask for the src sequence (optional).  CAUTION: we need to downsample
                this, if we are to support it.  Won't work correctly yet.
            src_key_padding_mask: the mask for the src keys per batch (optional).
        Shape:
            src: (S, N, E).
            mask: (S, S).
            src_key_padding_mask: (N, S).
            S is the source sequence length, N is the batch size, E is the feature number.
        Returns: output of shape (S', N, F), where S' is the downsampled sequence
            length and F is the number of output features (output_dim to constructor).
        """
        src_orig = src
        src = self.downsample(src)
        ds = self.downsample_factor
        if mask is not None:
            mask = mask[::ds, ::ds]
        if src_key_padding_mask is not None:
            # (N, S): downsample along the sequence axis, not the batch axis.
            src_key_padding_mask = src_key_padding_mask[..., ::ds]
        src = self.encoder(
            src, feature_mask=feature_mask, mask=mask,
            src_key_padding_mask=src_key_padding_mask,
        )
        return src
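
As a quick illustration of the strided mask downsampling above (a standalone sketch, not part of this commit): `[::ds]` keeps every ds-th position along an axis, so an (S, S) attention mask and an (N, S) key-padding mask shrink to the downsampled sequence length.

import torch

ds, S, N = 2, 8, 3
attn_mask = torch.zeros(S, S, dtype=torch.bool)         # (S, S)
key_padding_mask = torch.zeros(N, S, dtype=torch.bool)  # (N, S)

# Keep every ds-th query/key position: 8 positions -> 4.
print(attn_mask[::ds, ::ds].shape)        # torch.Size([4, 4])
print(key_padding_mask[..., ::ds].shape)  # torch.Size([3, 4])
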
class AttentionDownsample(torch.nn.Module):
    """
@@ -1734,6 +1679,71 @@ class ScalarMultiply(nn.Module):
    def forward(self, x):
        return x * self.scale

class ConvNeXt(nn.Module):
    """
    Our interpretation of the ConvNeXt module as used in https://arxiv.org/pdf/2206.14747.pdf
    """
    def __init__(self,
                 channels: int,
                 hidden_ratio: int = 4,
                 layerdrop_prob: FloatLike = None):
        super().__init__()
        kernel_size = 7
        pad = (kernel_size - 1) // 2
        hidden_channels = channels * hidden_ratio
        if layerdrop_prob is None:
            layerdrop_prob = ScheduledFloat((0.0, 0.1), (16000.0, 0.01))
        self.layerdrop_prob = layerdrop_prob

        self.depthwise_conv = nn.Conv2d(
            in_channels=channels,
            out_channels=channels,
            groups=channels,
            kernel_size=kernel_size,
            padding=(pad, pad))

        self.pointwise_conv1 = nn.Conv2d(
            in_channels=channels,
            out_channels=hidden_channels,
            kernel_size=1)

        self.hidden_balancer = ActivationBalancer(hidden_channels,
                                                  channel_dim=1,
                                                  min_positive=0.3,
                                                  max_positive=1.0,
                                                  min_abs=0.75,
                                                  max_abs=5.0,
                                                  min_prob=0.25)
        self.activation = SwooshL()

        self.pointwise_conv2 = ScaledConv2d(
            in_channels=hidden_channels,
            out_channels=channels,
            kernel_size=1,
            initial_scale=0.01)

    def forward(self, x: Tensor) -> Tensor:
        """
        x layout: (N, C, H, W), i.e. (batch_size, num_channels, num_frames, num_freqs)
        The returned value has the same shape as x.
        """
        # Layer-drop: only during training (and not when scripting), randomly
        # skip this module with probability layerdrop_prob.
        if (self.training and not torch.jit.is_scripting()
                and random.random() < float(self.layerdrop_prob)):
            return x
        bypass = x
        x = self.depthwise_conv(x)
        x = self.pointwise_conv1(x)
        x = self.hidden_balancer(x)
        x = self.activation(x)
        x = self.pointwise_conv2(x)
        return bypass + x
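
For readers without the repo's scaling.py utilities, the block above is structurally a standard ConvNeXt unit: a depthwise 7x7 convolution, a 1x1 expansion, a nonlinearity, a 1x1 projection, and a residual connection. Below is a minimal self-contained sketch in plain PyTorch, with nn.GELU standing in for SwooshL and ordinary nn.Conv2d standing in for ScaledConv2d and the ActivationBalancer (an illustration of the structure, not this commit's exact module):

import torch
import torch.nn as nn

class ConvNeXtSketch(nn.Module):
    def __init__(self, channels: int, hidden_ratio: int = 4):
        super().__init__()
        hidden = channels * hidden_ratio
        self.depthwise = nn.Conv2d(channels, channels, kernel_size=7,
                                   padding=3, groups=channels)
        self.expand = nn.Conv2d(channels, hidden, kernel_size=1)
        self.act = nn.GELU()                 # stand-in for SwooshL
        self.project = nn.Conv2d(hidden, channels, kernel_size=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # padding=3 with kernel 7 preserves (H, W), so the residual add is valid.
        return x + self.project(self.act(self.expand(self.depthwise(x))))

x = torch.randn(2, 32, 50, 40)               # (N, C, num_frames, num_freqs)
assert ConvNeXtSketch(32)(x).shape == x.shape
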
class Conv2dSubsampling(nn.Module):
    """Convolutional 2D subsampling (to 1/2 length).
@@ -1752,7 +1762,6 @@ class Conv2dSubsampling(nn.Module):
                 layer1_channels: int = 8,
                 layer2_channels: int = 32,
                 layer3_channels: int = 128,
                 bottleneck_channels: int = 64,
                 dropout: FloatLike = 0.1,
                 ) -> None:
        """
@@ -1778,7 +1787,7 @@ class Conv2dSubsampling(nn.Module):
        # training.  (The second one is necessary to stop its bias from getting
        # a too-large gradient).
        self.conv1 = nn.Sequential(
            nn.Conv2d(
                in_channels=1,
                out_channels=layer1_channels,
@@ -1797,21 +1806,29 @@ class Conv2dSubsampling(nn.Module):
                stride=2,
                padding=0,
            ),
            ActivationBalancer(layer2_channels,
                               channel_dim=1,
                               max_abs=4.0),
            SwooshR(),
            nn.Conv2d(
                in_channels=layer2_channels,
                out_channels=layer3_channels,
                kernel_size=3,
                stride=(1, 2),  # (time, freq)
            ),
            ActivationBalancer(layer3_channels,
                               channel_dim=1,
                               max_abs=4.0),
            SwooshR(),
        )
        self.convnext1 = nn.Sequential(ConvNeXt(layer2_channels),
                                       ConvNeXt(layer2_channels),
                                       ConvNeXt(layer2_channels),
                                       BasicNorm(layer2_channels,
                                                 channel_dim=1))

        self.conv2 = nn.Sequential(
            nn.Conv2d(
                in_channels=layer2_channels,
                out_channels=layer3_channels,
                kernel_size=3,
                stride=(1, 2),  # (time, freq)
            ))

        self.convnext2 = nn.Sequential(ConvNeXt(layer3_channels),
                                       ConvNeXt(layer3_channels),
                                       ConvNeXt(layer3_channels),
                                       BasicNorm(layer3_channels,
                                                 channel_dim=1))
        out_height = (((in_channels - 1) // 2) - 1) // 2
        self.scale = nn.Parameter(torch.ones(out_height * layer3_channels))
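
A quick sanity check of the out_height formula (an illustrative computation; the 80-bin input size is an assumption, not from the commit): a kernel-3, stride-2, unpadded convolution maps n to floor((n - 3) / 2) + 1, which equals (n - 1) // 2, and the frequency axis is halved that way twice on the way through.

in_channels = 80  # e.g. 80-bin log-mel filterbank features (assumed)
out_height = (((in_channels - 1) // 2) - 1) // 2
print(out_height)  # 19 -> self.scale has 19 * layer3_channels entries
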
@@ -1839,7 +1856,12 @@ class Conv2dSubsampling(nn.Module):
        # scaling x by 0.1 allows us to use a larger grad-scale in fp16 "amp" (automatic
        # mixed precision) training, since the weights in the first convolution are
        # otherwise the limiting factor for getting infinite gradients.
        x = self.conv1(x)
        x = self.convnext1(x)
        x = self.conv2(x)
        x = self.convnext2(x)
        # Now x is of shape (N, odim, ((T-3)//2 - 1)//2, ((idim-1)//2 - 1)//2)
        b, c, t, f = x.size()
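
The hunk ends here. The usual continuation (a hedged sketch of what forward() presumably does next; the remaining lines are outside this diff) folds channels and frequency bins into one feature vector per frame:

# Assumed continuation, not shown in this hunk:
x = x.transpose(1, 2).reshape(b, t, c * f)  # (N, T', C * F')
# self.scale (size out_height * layer3_channels) would presumably scale this
# flattened feature; the exact ordering depends on code outside the diff.
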