Simplify Conv2dSubsampling, removing all but one ConvNext layer

This commit is contained in:
Daniel Povey 2023-01-12 20:14:51 +08:00
parent 65f15c9d14
commit 9e4b84f374

View File

@ -1944,7 +1944,7 @@ class Conv2dSubsampling(nn.Module):
# training. (The second one is necessary to stop its bias from getting # training. (The second one is necessary to stop its bias from getting
# a too-large gradient). # a too-large gradient).
self.conv1 = nn.Sequential( self.conv = nn.Sequential(
nn.Conv2d( nn.Conv2d(
in_channels=1, in_channels=1,
out_channels=layer1_channels, out_channels=layer1_channels,
@ -1964,21 +1964,9 @@ class Conv2dSubsampling(nn.Module):
padding=0, padding=0,
), ),
Balancer(layer2_channels, Balancer(layer2_channels,
channel_dim=1, channel_dim=1,
max_abs=4.0), max_abs=4.0),
SwooshR(), SwooshR(),
)
self.convnext1 = nn.Sequential(ConvNeXt(layer2_channels, kernel_size=(5, 7)),
ConvNeXt(layer2_channels, kernel_size=(5, 7)),
BasicNorm(layer2_channels,
channel_dim=1))
cur_width = (in_channels - 1) // 2
self.conv2 = nn.Sequential(
nn.Conv2d( nn.Conv2d(
in_channels=layer2_channels, in_channels=layer2_channels,
out_channels=layer3_channels, out_channels=layer3_channels,
@ -1986,21 +1974,18 @@ class Conv2dSubsampling(nn.Module):
stride=(1, 2), # (time, freq) stride=(1, 2), # (time, freq)
), ),
Balancer(layer3_channels, Balancer(layer3_channels,
channel_dim=1, channel_dim=1,
max_abs=4.0), max_abs=4.0),
SwooshR(), SwooshR(),
) )
self.convnext2 = nn.Sequential(ConvNeXt(layer3_channels, kernel_size=(7, 7)), cur_width = (in_channels - 1) // 2
ConvNeXt(layer3_channels, kernel_size=(7, 7)),
ConvNeXt(layer3_channels, kernel_size=(7, 7))) # just one convnext layer
self.convnext = ConvNeXt(layer3_channels, kernel_size=(7, 7))
out_width = (((in_channels - 1) // 2) - 1) // 2 out_width = (((in_channels - 1) // 2) - 1) // 2
self.scale = nn.Parameter(torch.ones(out_width * layer3_channels))
self.scale_max = 1.0
self.scale_min = ScheduledFloat((0.0, 0.9), (4000.0, 0.1))
self.out = nn.Linear(out_width * layer3_channels, out_channels) self.out = nn.Linear(out_width * layer3_channels, out_channels)
# use a larger than normal grad_scale on this whitening module; there is # use a larger than normal grad_scale on this whitening module; there is
# only one such module, so there is not a concern about adding together # only one such module, so there is not a concern about adding together
@ -2031,13 +2016,8 @@ class Conv2dSubsampling(nn.Module):
# scaling x by 0.1 allows us to use a larger grad-scale in fp16 "amp" (automatic mixed precision) # scaling x by 0.1 allows us to use a larger grad-scale in fp16 "amp" (automatic mixed precision)
# training, since the weights in the first convolution are otherwise the limiting factor for getting infinite # training, since the weights in the first convolution are otherwise the limiting factor for getting infinite
# gradients. # gradients.
x = self.conv1(x) x = self.conv(x)
x = self.convnext1(x) x = self.convnext(x)
x = self.conv2(x)
x = self.convnext2(x)
# Now x is of shape (N, odim, ((T-3)//2 - 1)//2, ((idim-1)//2 - 1)//2) # Now x is of shape (N, odim, ((T-3)//2 - 1)//2, ((idim-1)//2 - 1)//2)
b, c, t, f = x.size() b, c, t, f = x.size()
@ -2045,13 +2025,8 @@ class Conv2dSubsampling(nn.Module):
x = x.transpose(1, 2).reshape(b, t, c * f) x = x.transpose(1, 2).reshape(b, t, c * f)
# now x: (N, ((T-1)//2 - 1))//2, out_width * layer3_channels)) # now x: (N, ((T-1)//2 - 1))//2, out_width * layer3_channels))
x = x * limit_param_value(self.scale,
min=float(self.scale_min),
max=float(self.scale_max),
training=self.training)
# Now x is of shape (N, ((T-1)//2 - 1))//2, odim)
x = self.out(x) x = self.out(x)
# Now x is of shape (N, ((T-1)//2 - 1))//2, odim)
x = self.out_whiten(x) x = self.out_whiten(x)
x = self.out_norm(x) x = self.out_norm(x)
x = self.dropout(x) x = self.dropout(x)