Simplify Conv2dSubsampling, removing all but one ConvNext layer

This commit is contained in:
Daniel Povey 2023-01-12 20:14:51 +08:00
parent 65f15c9d14
commit 9e4b84f374

View File

@ -1944,7 +1944,7 @@ class Conv2dSubsampling(nn.Module):
# training. (The second one is necessary to stop its bias from getting
# a too-large gradient).
self.conv1 = nn.Sequential(
self.conv = nn.Sequential(
nn.Conv2d(
in_channels=1,
out_channels=layer1_channels,
@ -1967,18 +1967,6 @@ class Conv2dSubsampling(nn.Module):
channel_dim=1,
max_abs=4.0),
SwooshR(),
)
self.convnext1 = nn.Sequential(ConvNeXt(layer2_channels, kernel_size=(5, 7)),
ConvNeXt(layer2_channels, kernel_size=(5, 7)),
BasicNorm(layer2_channels,
channel_dim=1))
cur_width = (in_channels - 1) // 2
self.conv2 = nn.Sequential(
nn.Conv2d(
in_channels=layer2_channels,
out_channels=layer3_channels,
@ -1991,16 +1979,13 @@ class Conv2dSubsampling(nn.Module):
SwooshR(),
)
self.convnext2 = nn.Sequential(ConvNeXt(layer3_channels, kernel_size=(7, 7)),
ConvNeXt(layer3_channels, kernel_size=(7, 7)),
ConvNeXt(layer3_channels, kernel_size=(7, 7)))
cur_width = (in_channels - 1) // 2
# just one convnext layer
self.convnext = ConvNeXt(layer3_channels, kernel_size=(7, 7))
out_width = (((in_channels - 1) // 2) - 1) // 2
self.scale = nn.Parameter(torch.ones(out_width * layer3_channels))
self.scale_max = 1.0
self.scale_min = ScheduledFloat((0.0, 0.9), (4000.0, 0.1))
self.out = nn.Linear(out_width * layer3_channels, out_channels)
# use a larger than normal grad_scale on this whitening module; there is
# only one such module, so there is not a concern about adding together
@ -2031,13 +2016,8 @@ class Conv2dSubsampling(nn.Module):
# scaling x by 0.1 allows us to use a larger grad-scale in fp16 "amp" (automatic mixed precision)
# training, since the weights in the first convolution are otherwise the limiting factor for getting infinite
# gradients.
x = self.conv1(x)
x = self.convnext1(x)
x = self.conv2(x)
x = self.convnext2(x)
x = self.conv(x)
x = self.convnext(x)
# Now x is of shape (N, odim, ((T-3)//2 - 1)//2, ((idim-1)//2 - 1)//2)
b, c, t, f = x.size()
@ -2045,13 +2025,8 @@ class Conv2dSubsampling(nn.Module):
x = x.transpose(1, 2).reshape(b, t, c * f)
# now x: (N, ((T-1)//2 - 1))//2, out_width * layer3_channels))
x = x * limit_param_value(self.scale,
min=float(self.scale_min),
max=float(self.scale_max),
training=self.training)
# Now x is of shape (N, ((T-1)//2 - 1))//2, odim)
x = self.out(x)
# Now x is of shape (N, ((T-1)//2 - 1))//2, odim)
x = self.out_whiten(x)
x = self.out_norm(x)
x = self.dropout(x)