mirror of https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
Implement Nextformer-style frontend
commit 076b18db60
parent 37a8c30136
@@ -735,61 +735,6 @@ class DownsampledZipformerEncoder(nn.Module):
         return self.out_combiner(src_orig, src)


-class DownsamplingZipformerEncoder(nn.Module):
-    r"""
-    DownsamplingZipformerEncoder is a zipformer encoder that downsamples its
-    input by a specified factor before feeding it to the zipformer layers.
-    """
-    def __init__(self,
-                 encoder: nn.Module,
-                 input_dim: int,
-                 output_dim: int,
-                 downsample: int):
-        super(DownsamplingZipformerEncoder, self).__init__()
-        self.downsample_factor = downsample
-        self.downsample = AttentionDownsample(input_dim, output_dim, downsample)
-        self.encoder = encoder
-
-    def forward(self,
-                src: Tensor,
-                feature_mask: Union[Tensor, float] = 1.0,
-                mask: Optional[Tensor] = None,
-                src_key_padding_mask: Optional[Tensor] = None,
-                ) -> Tensor:
-        r"""Downsample, then go through the encoder.
-
-        Args:
-            src: the sequence to the encoder (required).
-            feature_mask: something that broadcasts with src, that we'll
-                multiply `src` by at every layer.  feature_mask is expected
-                to be already downsampled by self.downsample_factor.
-            mask: the mask for the src sequence (optional).  CAUTION: we need
-                to downsample this if we are to support it.  Won't work
-                correctly yet.
-            src_key_padding_mask: the mask for the src keys per batch (optional).
-
-        Shape:
-            src: (S, N, E).
-            mask: (S, S).
-            src_key_padding_mask: (N, S).
-            S is the source sequence length, N is the batch size, E is the
-            feature number.
-
-        Returns: output of shape (S', N, F), where S' is the downsampled
-            source length and F is the number of output features
-            (output_dim to constructor).
-        """
-        src_orig = src
-        src = self.downsample(src)
-        ds = self.downsample_factor
-        if mask is not None:
-            mask = mask[::ds, ::ds]
-        if src_key_padding_mask is not None:
-            # src_key_padding_mask has shape (N, S), so subsample along dim 1.
-            src_key_padding_mask = src_key_padding_mask[:, ::ds]
-
-        src = self.encoder(
-            src, feature_mask=feature_mask, mask=mask,
-            src_key_padding_mask=src_key_padding_mask,
-        )
-        return src
-
-
 class AttentionDownsample(torch.nn.Module):
     """
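For reference, the mask subsampling that the removed forward() performs is plain strided slicing along the time axis. A minimal sketch, with shapes as in the docstring and values hypothetical:

import torch

ds = 2  # hypothetical downsample factor
# Key-padding mask of shape (N, S); True marks padded frames.
src_key_padding_mask = torch.tensor([[False, False, False, False, True, True]])

# Keep every ds-th frame: the downsampled mask has length (S + ds - 1) // ds,
# and position i corresponds to original frame i * ds.
print(src_key_padding_mask[:, ::ds])  # tensor([[False, False,  True]])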
@@ -1734,6 +1679,71 @@ class ScalarMultiply(nn.Module):
     def forward(self, x):
         return x * self.scale


+class ConvNeXt(nn.Module):
+    """
+    Our interpretation of the ConvNeXt module as used in
+    https://arxiv.org/pdf/2206.14747.pdf
+    """
+    def __init__(self,
+                 channels: int,
+                 hidden_ratio: int = 4,
+                 layerdrop_prob: Optional[FloatLike] = None):
+        super().__init__()
+        kernel_size = 7
+        pad = (kernel_size - 1) // 2
+        hidden_channels = channels * hidden_ratio
+        if layerdrop_prob is None:
+            layerdrop_prob = ScheduledFloat((0.0, 0.1), (16000.0, 0.01))
+        self.layerdrop_prob = layerdrop_prob
+
+        self.depthwise_conv = nn.Conv2d(
+            in_channels=channels,
+            out_channels=channels,
+            groups=channels,
+            kernel_size=kernel_size,
+            padding=(pad, pad))
+
+        self.pointwise_conv1 = nn.Conv2d(
+            in_channels=channels,
+            out_channels=hidden_channels,
+            kernel_size=1)
+
+        self.hidden_balancer = ActivationBalancer(hidden_channels,
+                                                  channel_dim=1,
+                                                  min_positive=0.3,
+                                                  max_positive=1.0,
+                                                  min_abs=0.75,
+                                                  max_abs=5.0,
+                                                  min_prob=0.25)
+        self.activation = SwooshL()
+        self.pointwise_conv2 = ScaledConv2d(
+            in_channels=hidden_channels,
+            out_channels=channels,
+            kernel_size=1,
+            initial_scale=0.01)
+
+    def forward(self, x: Tensor) -> Tensor:
+        """
+        x layout: (N, C, H, W), i.e. (batch_size, num_channels, num_frames,
+        num_freqs).  The returned value has the same shape as x.
+        """
+        # Layer drop: during eager-mode training, skip this module with
+        # probability layerdrop_prob; always apply it in eval or scripting.
+        if (self.training and not torch.jit.is_scripting()
+                and random.random() < float(self.layerdrop_prob)):
+            return x
+
+        bypass = x
+        x = self.depthwise_conv(x)
+        x = self.pointwise_conv1(x)
+        x = self.hidden_balancer(x)
+        x = self.activation(x)
+        x = self.pointwise_conv2(x)
+        return bypass + x


 class Conv2dSubsampling(nn.Module):
     """Convolutional 2D subsampling (to 1/2 length).
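As a quick sanity check of the pattern the new block follows (depthwise 7x7 convolution, 1x1 expansion by hidden_ratio, activation, 1x1 projection, residual add), here is a shape-only sketch in plain PyTorch; GELU stands in for SwooshL, and the ActivationBalancer and ScaledConv2d specifics are omitted:

import torch
import torch.nn as nn

channels, hidden_ratio = 32, 4
hidden = channels * hidden_ratio
block = nn.Sequential(
    nn.Conv2d(channels, channels, kernel_size=7, padding=3, groups=channels),  # depthwise
    nn.Conv2d(channels, hidden, kernel_size=1),  # pointwise expansion
    nn.GELU(),                                   # stand-in for SwooshL
    nn.Conv2d(hidden, channels, kernel_size=1),  # pointwise projection
)

x = torch.randn(2, channels, 50, 40)  # (N, C, num_frames, num_freqs)
y = x + block(x)                      # residual bypass
assert y.shape == x.shape             # ConvNeXt blocks preserve shape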
@@ -1752,7 +1762,6 @@ class Conv2dSubsampling(nn.Module):
         layer1_channels: int = 8,
         layer2_channels: int = 32,
         layer3_channels: int = 128,
-        bottleneck_channels: int = 64,
         dropout: FloatLike = 0.1,
     ) -> None:
         """
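The FloatLike annotations here mean a value may be a plain float or a batch-count-dependent schedule, like the ScheduledFloat((0.0, 0.1), (16000.0, 0.01)) used for layerdrop_prob above. A hypothetical stand-in showing the piecewise-linear interpolation such a schedule encodes (scheduled_float is an illustration, not icefall's API):

def scheduled_float(batch_count: float,
                    points=((0.0, 0.1), (16000.0, 0.01))) -> float:
    # Linearly interpolate between (batch_count, value) breakpoints,
    # clamping outside the range.
    (x0, y0), (x1, y1) = points
    if batch_count <= x0:
        return y0
    if batch_count >= x1:
        return y1
    t = (batch_count - x0) / (x1 - x0)
    return y0 + t * (y1 - y0)

print(scheduled_float(0.0))      # 0.1: aggressive layerdrop early in training
print(scheduled_float(16000.0))  # 0.01: nearly always apply the block later on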
@@ -1778,7 +1787,7 @@ class Conv2dSubsampling(nn.Module):
         # training. (The second one is necessary to stop its bias from getting
         # a too-large gradient).

-        self.conv = nn.Sequential(
+        self.conv1 = nn.Sequential(
             nn.Conv2d(
                 in_channels=1,
                 out_channels=layer1_channels,
@@ -1797,21 +1806,29 @@ class Conv2dSubsampling(nn.Module):
                 stride=2,
                 padding=0,
             ),
             ActivationBalancer(layer2_channels,
                                channel_dim=1,
                                max_abs=4.0),
             SwooshR(),
-            nn.Conv2d(
-                in_channels=layer2_channels,
-                out_channels=layer3_channels,
-                kernel_size=3,
-                stride=(1, 2),  # (time, freq)
-            ),
-            ActivationBalancer(layer3_channels,
-                               channel_dim=1,
-                               max_abs=4.0),
-            SwooshR(),
         )
+
+        self.convnext1 = nn.Sequential(ConvNeXt(layer2_channels),
+                                       ConvNeXt(layer2_channels),
+                                       ConvNeXt(layer2_channels),
+                                       BasicNorm(layer2_channels,
+                                                 channel_dim=1))
+
+        self.conv2 = nn.Sequential(
+            nn.Conv2d(
+                in_channels=layer2_channels,
+                out_channels=layer3_channels,
+                kernel_size=3,
+                stride=(1, 2),  # (time, freq)
+            ))
+
+        self.convnext2 = nn.Sequential(ConvNeXt(layer3_channels),
+                                       ConvNeXt(layer3_channels),
+                                       ConvNeXt(layer3_channels),
+                                       BasicNorm(layer3_channels,
+                                                 channel_dim=1))
+
+        out_height = (((in_channels - 1) // 2) - 1) // 2
+
+        self.scale = nn.Parameter(torch.ones(out_height * layer3_channels))
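The out_height expression reflects the frequency-axis geometry: a kernel-3, stride-2, unpadded convolution maps W bins to (W - 3)//2 + 1 = (W - 1)//2, and that reduction happens once in conv1 and once in conv2 (assuming, as the formula implies, that the other convolutions pad or preserve the frequency axis). A quick check with an assumed 80-bin input:

def freq_out(w: int, kernel: int = 3, stride: int = 2) -> int:
    # Output width of an unpadded strided convolution.
    return (w - kernel) // stride + 1

in_channels = 80  # assumed number of fbank bins
h = freq_out(freq_out(in_channels))  # two stride-2 reductions: 80 -> 39 -> 19
assert h == (((in_channels - 1) // 2) - 1) // 2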
@@ -1839,7 +1856,12 @@ class Conv2dSubsampling(nn.Module):
         # scaling x by 0.1 allows us to use a larger grad-scale in fp16 "amp"
         # (automatic mixed precision) training, since the weights in the first
         # convolution are otherwise the limiting factor for getting infinite
         # gradients.
-        x = self.conv(x)
+        x = self.conv1(x)
+        x = self.convnext1(x)
+        x = self.conv2(x)
+        x = self.convnext2(x)

         # Now x is of shape (N, odim, ((T-3)//2 - 1)//2, ((idim-1)//2 - 1)//2)
         b, c, t, f = x.size()
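A shape-only sketch of the restructured stem that this forward pass now runs, using plain Conv2d stand-ins; the ConvNeXt stacks, balancers, and activations are omitted since they preserve shape, and the first convolution's padding is an assumption:

import torch
import torch.nn as nn

conv1 = nn.Sequential(
    nn.Conv2d(1, 8, kernel_size=3, padding=(0, 1)),        # assumed freq padding
    nn.Conv2d(8, 32, kernel_size=3, stride=2, padding=0),  # halves time and freq
)
conv2 = nn.Conv2d(32, 128, kernel_size=3, stride=(1, 2))   # halves freq only

x = torch.randn(4, 1, 100, 80)  # (N, 1, T, num_freq_bins)
y = conv2(conv1(x))
print(y.shape)                  # torch.Size([4, 128, 46, 19])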