Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-12-11 06:55:27 +00:00)
Simplify Conv2dSubsampling, removing all but one ConvNext layer
commit 9e4b84f374 (parent 65f15c9d14)
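In brief, the hunks below collapse the former conv1 / convnext1 / conv2 / convnext2 pipeline into a single conv stack followed by one ConvNeXt block, and drop the learned per-feature scale on the flattened output. The following is a minimal, self-contained sketch of the post-commit layout, not the real module: ReLU stands in for icefall's Balancer/SwooshR, a plain depthwise 7x7 convolution stands in for ConvNeXt, and the kernel/stride/padding values are illustrative assumptions rather than lines quoted from this diff.

import torch
import torch.nn as nn


class Conv2dSubsamplingSketch(nn.Module):
    # Post-commit layout: one fused conv stack, then a single ConvNeXt-like
    # block. Stand-ins are used purely so this sketch runs on its own.
    def __init__(self, in_channels: int, out_channels: int,
                 layer1_channels: int = 8, layer2_channels: int = 32,
                 layer3_channels: int = 128) -> None:
        super().__init__()
        self.conv = nn.Sequential(
            # (N, 1, T, idim) -> (N, c1, T-2, idim): pad freq axis only (assumed)
            nn.Conv2d(1, layer1_channels, kernel_size=3, padding=(0, 1)),
            nn.ReLU(),  # stand-in for Balancer(...) + SwooshR()
            # downsample time and freq by 2
            nn.Conv2d(layer1_channels, layer2_channels, kernel_size=3, stride=2),
            nn.ReLU(),
            # downsample freq by 2 again; stride 1 on time
            nn.Conv2d(layer2_channels, layer3_channels, kernel_size=3, stride=(1, 2)),
            nn.ReLU(),
        )
        # just one convnext layer (stand-in: shape-preserving depthwise conv)
        self.convnext = nn.Conv2d(layer3_channels, layer3_channels, kernel_size=7,
                                  padding=3, groups=layer3_channels)
        # same width arithmetic as the diff: two stride-2 steps on the freq axis
        out_width = (((in_channels - 1) // 2) - 1) // 2
        self.out = nn.Linear(out_width * layer3_channels, out_channels)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.conv(x.unsqueeze(1))  # (N, T, idim) -> (N, c3, T', out_width)
        x = self.convnext(x)
        b, c, t, f = x.size()
        x = x.transpose(1, 2).reshape(b, t, c * f)
        return self.out(x)             # (N, T', out_channels)


if __name__ == "__main__":
    m = Conv2dSubsamplingSketch(in_channels=80, out_channels=192)
    print(m(torch.randn(2, 100, 80)).shape)  # torch.Size([2, 46, 192])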
@@ -1944,7 +1944,7 @@ class Conv2dSubsampling(nn.Module):
         # training. (The second one is necessary to stop its bias from getting
         # a too-large gradient).

-        self.conv1 = nn.Sequential(
+        self.conv = nn.Sequential(
             nn.Conv2d(
                 in_channels=1,
                 out_channels=layer1_channels,
@@ -1967,18 +1967,6 @@ class Conv2dSubsampling(nn.Module):
                      channel_dim=1,
                      max_abs=4.0),
             SwooshR(),
-        )
-
-        self.convnext1 = nn.Sequential(ConvNeXt(layer2_channels, kernel_size=(5, 7)),
-                                       ConvNeXt(layer2_channels, kernel_size=(5, 7)),
-                                       BasicNorm(layer2_channels,
-                                                 channel_dim=1))
-
-
-        cur_width = (in_channels - 1) // 2
-
-
-        self.conv2 = nn.Sequential(
             nn.Conv2d(
                 in_channels=layer2_channels,
                 out_channels=layer3_channels,
@@ -1991,16 +1979,13 @@ class Conv2dSubsampling(nn.Module):
             SwooshR(),
         )

-        self.convnext2 = nn.Sequential(ConvNeXt(layer3_channels, kernel_size=(7, 7)),
-                                       ConvNeXt(layer3_channels, kernel_size=(7, 7)),
-                                       ConvNeXt(layer3_channels, kernel_size=(7, 7)))
+        cur_width = (in_channels - 1) // 2
+
+        # just one convnext layer
+        self.convnext = ConvNeXt(layer3_channels, kernel_size=(7, 7))

         out_width = (((in_channels - 1) // 2) - 1) // 2

-        self.scale = nn.Parameter(torch.ones(out_width * layer3_channels))
-        self.scale_max = 1.0
-        self.scale_min = ScheduledFloat((0.0, 0.9), (4000.0, 0.1))
-
         self.out = nn.Linear(out_width * layer3_channels, out_channels)
         # use a larger than normal grad_scale on this whitening module; there is
         # only one such module, so there is not a concern about adding together
@@ -2031,13 +2016,8 @@ class Conv2dSubsampling(nn.Module):
         # scaling x by 0.1 allows us to use a larger grad-scale in fp16 "amp" (automatic mixed precision)
         # training, since the weights in the first convolution are otherwise the limiting factor for getting infinite
         # gradients.
-        x = self.conv1(x)
-        x = self.convnext1(x)
-
-
-        x = self.conv2(x)
-        x = self.convnext2(x)
-
+        x = self.conv(x)
+        x = self.convnext(x)

         # Now x is of shape (N, odim, ((T-3)//2 - 1)//2, ((idim-1)//2 - 1)//2)
         b, c, t, f = x.size()
@@ -2045,13 +2025,8 @@ class Conv2dSubsampling(nn.Module):
         x = x.transpose(1, 2).reshape(b, t, c * f)
         # now x: (N, ((T-1)//2 - 1))//2, out_width * layer3_channels))

-        x = x * limit_param_value(self.scale,
-                                  min=float(self.scale_min),
-                                  max=float(self.scale_max),
-                                  training=self.training)
-
-        # Now x is of shape (N, ((T-1)//2 - 1))//2, odim)
         x = self.out(x)
+        # Now x is of shape (N, ((T-1)//2 - 1))//2, odim)
         x = self.out_whiten(x)
         x = self.out_norm(x)
         x = self.dropout(x)
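As a quick sanity check on the width arithmetic kept by this commit, the computation below uses an assumed 80-dimensional fbank input and an assumed layer3_channels of 128; neither value appears in this diff.

in_channels = 80       # assumed fbank dimension; not stated in this diff
layer3_channels = 128  # assumed channel count; not stated in this diff

cur_width = (in_channels - 1) // 2               # 39: freq width after the stride-2 conv
out_width = (((in_channels - 1) // 2) - 1) // 2  # 19: width after the second freq downsample

# in_features of self.out = nn.Linear(out_width * layer3_channels, out_channels)
print(cur_width, out_width, out_width * layer3_channels)  # 39 19 2432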