From 180c440e63908f11158f4a472e8e27a0b8270242 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 22 Dec 2022 17:25:30 +0800 Subject: [PATCH] Make BasicNorm after convnext1 operate over all frequency bins. --- .../pruned_transducer_stateless7/zipformer.py | 24 +++++++++++++------ 1 file changed, 17 insertions(+), 7 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py index dec6142b8..75ea8a5d8 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py @@ -1812,9 +1812,13 @@ class Conv2dSubsampling(nn.Module): ) self.convnext1 = nn.Sequential(ConvNeXt(layer2_channels), - ConvNeXt(layer2_channels), - BasicNorm(layer2_channels, - channel_dim=1)) + ConvNeXt(layer2_channels)) + + + cur_width = (in_channels - 1) // 2 + + self.norm1 = BasicNorm(layer2_channels * cur_width, + channel_dim=-1) self.conv2 = nn.Sequential( @@ -1834,13 +1838,13 @@ class Conv2dSubsampling(nn.Module): ConvNeXt(layer3_channels), ConvNeXt(layer3_channels)) - out_height = (((in_channels - 1) // 2) - 1) // 2 + out_width = (((in_channels - 1) // 2) - 1) // 2 - self.scale = nn.Parameter(torch.ones(out_height * layer3_channels)) + self.scale = nn.Parameter(torch.ones(out_width * layer3_channels)) self.scale_max = 1.0 self.scale_min = ScheduledFloat((0.0, 0.9), (4000.0, 0.1)) - self.out = nn.Linear(out_height * layer3_channels, out_channels) + self.out = nn.Linear(out_width * layer3_channels, out_channels) self.out_norm = BasicNorm(out_channels) self.dropout = Dropout2(dropout) @@ -1863,6 +1867,12 @@ class Conv2dSubsampling(nn.Module): # gradients. 
x = self.conv1(x) x = self.convnext1(x) + + (batch_size, layer2_channels, num_frames, cur_width) = x.shape + x = x.permute(0, 2, 1, 3).reshape(batch_size, num_frames, layer2_channels * cur_width) + x = self.norm1(x) + x = x.reshape(batch_size, num_frames, layer2_channels, cur_width).permute(0, 2, 1, 3) + x = self.conv2(x) x = self.convnext2(x) @@ -1871,7 +1881,7 @@ class Conv2dSubsampling(nn.Module): b, c, t, f = x.size() x = x.transpose(1, 2).reshape(b, t, c * f) - # now x: (N, ((T-1)//2 - 1))//2, out_height * layer3_channels)) + # now x: (N, ((T-1)//2 - 1)//2, out_width * layer3_channels) x = x * limit_param_value(self.scale, min=float(self.scale_min),