diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
index e4385f87a..25676801d 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
@@ -1712,10 +1712,13 @@ class ConvNeXt(nn.Module):
         self.out_balancer = ActivationBalancer(
             channels, channel_dim=1,
-            min_positive=0.5, max_positive=0.5,
+            min_positive=0.4, max_positive=0.6,
             min_abs=0.25, max_abs=6.0,
         )
-
+        self.out_whiten = Whiten(num_groups=1,
+                                 whitening_limit=5.0,
+                                 prob=(0.025, 0.25),
+                                 grad_scale=0.01)

     def forward(self, x: Tensor) -> Tensor:
@@ -1739,6 +1742,10 @@ class ConvNeXt(nn.Module):
         x = bypass + x
         x = self.out_balancer(x)
+        x = x.transpose(1, 3) # (N, W, H, C); need channel dim to be last
+        x = self.out_whiten(x)
+        x = x.transpose(1, 3) # (N, C, H, W)
+
         return x

@@ -1845,13 +1852,13 @@ class Conv2dSubsampling(nn.Module):
         self.scale_min = ScheduledFloat((0.0, 0.9), (4000.0, 0.1))
         self.out = nn.Linear(out_width * layer3_channels, out_channels)
-        # use a much larger than normal grad_scale on this whitening module;
-        # there is only one such module, so there is not a concern about adding
-        # together many copies of this extra gradient term.
+        # use a larger than normal grad_scale on this whitening module; there is
+        # only one such module, so there is not a concern about adding together
+        # many copies of this extra gradient term.
         self.out_whiten = Whiten(num_groups=1,
                                  whitening_limit=_whitening_schedule(4.0),
                                  prob=(0.025, 0.25),
-                                 grad_scale=0.05)
+                                 grad_scale=0.02)
         self.out_norm = BasicNorm(out_channels)
         self.dropout = Dropout2(dropout)
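
For context, a minimal sketch of the transpose pattern the ConvNeXt hunk introduces: per the patch's own comment, Whiten (from scaling.py) needs the channel dim to be last, so the (N, C, H, W) activation is flipped to channel-last before the call and flipped back afterwards. `IdentityWhiten` below is a hypothetical stand-in for the real module, whose numerical effect (as the grad_scale comment suggests) is an extra gradient term rather than a change to the forward output; only the shapes and transposes here match the patched forward().

```python
import torch
from torch import Tensor, nn


class IdentityWhiten(nn.Module):
    """Hypothetical stand-in for scaling.Whiten: like the real module, it
    treats the LAST dimension as the channel dimension to be whitened.
    Here the forward pass is a plain identity; the real module also injects
    a whitening gradient term (scaled by grad_scale) in the backward pass."""

    def forward(self, x: Tensor) -> Tensor:
        # The real Whiten operates over the last dim; we just sanity-check it.
        assert x.shape[-1] > 1, "channel dim must be last"
        return x


whiten = IdentityWhiten()
x = torch.randn(2, 384, 19, 9)  # (N, C, H, W), as inside ConvNeXt.forward

x = x.transpose(1, 3)  # (N, W, H, C); put the channel dim last for Whiten
x = whiten(x)
x = x.transpose(1, 3)  # back to (N, C, H, W)
print(x.shape)  # torch.Size([2, 384, 19, 9])
```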