diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py index 4cd98f384..ff86165b7 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py @@ -466,8 +466,8 @@ class BasicNorm(torch.nn.Module): channel_dim: int = -1, # CAUTION: see documentation. eps: float = 0.25, learn_eps: bool = True, - eps_min: float = -2.0, - eps_max: float = 2.0, + eps_min: float = -3.0, + eps_max: float = 3.0, ) -> None: super(BasicNorm, self).__init__() self.num_channels = num_channels @@ -487,8 +487,8 @@ class BasicNorm(torch.nn.Module): eps = limit_param_value(self.eps, min=self.eps_min, max=self.eps_max) eps = eps.exp() scales = ( - (torch.mean(x ** 2, dim=self.channel_dim, keepdim=True) + eps) / - (1.0 + eps) + (torch.mean(x ** 2, dim=self.channel_dim, keepdim=True) + eps) + # / (1.0 + eps) ) ** -0.5 return x * scales diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py index 9c4b1251d..974bb5d28 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py @@ -573,8 +573,9 @@ class ZipformerEncoderLayer(nn.Module): src = src + self.feed_forward2(src) src = self.balancer(src) + src = self.norm_final(src) - delta = self.norm_final(src - src_orig) + delta = src - src_orig src = src_orig + delta * self.get_bypass_scale(src.shape[1]) src = self.whiten(src) @@ -1820,14 +1821,13 @@ class Conv2dSubsampling(nn.Module): ) self.convnext1 = nn.Sequential(ConvNeXt(layer2_channels), - ConvNeXt(layer2_channels)) + ConvNeXt(layer2_channels), + BasicNorm(layer2_channels, + channel_dim=1)) cur_width = (in_channels - 1) // 2 - self.norm1 = BasicNorm(layer2_channels * cur_width, - channel_dim=-1) - self.conv2 = nn.Sequential( nn.Conv2d( @@ -1883,10 +1883,6 @@ class Conv2dSubsampling(nn.Module): x = self.conv1(x) x = self.convnext1(x) - (batch_size, layer2_channels, num_frames, cur_width) = x.shape - x = x.permute(0, 2, 1, 3).reshape(batch_size, num_frames, layer2_channels * cur_width) - x = self.norm1(x) - x = x.reshape(batch_size, num_frames, layer2_channels, cur_width).permute(0, 2, 1, 3) x = self.conv2(x) x = self.convnext2(x)