Make BasicNorm after convnext1 operate over all frequency bins.
This commit is contained in:
parent
dd7257f01b
commit
180c440e63
@ -1812,9 +1812,13 @@ class Conv2dSubsampling(nn.Module):
|
||||
)
|
||||
|
||||
self.convnext1 = nn.Sequential(ConvNeXt(layer2_channels),
|
||||
ConvNeXt(layer2_channels),
|
||||
BasicNorm(layer2_channels,
|
||||
channel_dim=1))
|
||||
ConvNeXt(layer2_channels))
|
||||
|
||||
|
||||
cur_width = (in_channels - 1) // 2
|
||||
|
||||
self.norm1 = BasicNorm(layer2_channels * cur_width,
|
||||
channel_dim=-1)
|
||||
|
||||
|
||||
self.conv2 = nn.Sequential(
|
||||
@ -1834,13 +1838,13 @@ class Conv2dSubsampling(nn.Module):
|
||||
ConvNeXt(layer3_channels),
|
||||
ConvNeXt(layer3_channels))
|
||||
|
||||
out_height = (((in_channels - 1) // 2) - 1) // 2
|
||||
out_width = (((in_channels - 1) // 2) - 1) // 2
|
||||
|
||||
self.scale = nn.Parameter(torch.ones(out_height * layer3_channels))
|
||||
self.scale = nn.Parameter(torch.ones(out_width * layer3_channels))
|
||||
self.scale_max = 1.0
|
||||
self.scale_min = ScheduledFloat((0.0, 0.9), (4000.0, 0.1))
|
||||
|
||||
self.out = nn.Linear(out_height * layer3_channels, out_channels)
|
||||
self.out = nn.Linear(out_width * layer3_channels, out_channels)
|
||||
|
||||
self.out_norm = BasicNorm(out_channels)
|
||||
self.dropout = Dropout2(dropout)
|
||||
@ -1863,6 +1867,12 @@ class Conv2dSubsampling(nn.Module):
|
||||
# gradients.
|
||||
x = self.conv1(x)
|
||||
x = self.convnext1(x)
|
||||
|
||||
(batch_size, layer2_channels, num_frames, cur_width) = x.shape
|
||||
x = x.permute(0, 2, 1, 3).reshape(batch_size, num_frames, layer2_channels * cur_width)
|
||||
x = self.norm1(x)
|
||||
x = x.reshape(batch_size, num_frames, layer2_channels, cur_width).permute(0, 2, 1, 3)
|
||||
|
||||
x = self.conv2(x)
|
||||
x = self.convnext2(x)
|
||||
|
||||
@ -1871,7 +1881,7 @@ class Conv2dSubsampling(nn.Module):
|
||||
b, c, t, f = x.size()
|
||||
|
||||
x = x.transpose(1, 2).reshape(b, t, c * f)
|
||||
# now x: (N, ((T-1)//2 - 1))//2, out_height * layer3_channels))
|
||||
# now x: (N, ((T-1)//2 - 1))//2, out_width * layer3_channels))
|
||||
|
||||
x = x * limit_param_value(self.scale,
|
||||
min=float(self.scale_min),
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user