diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/subsampling.py b/egs/librispeech/ASR/pruned_transducer_stateless2/subsampling.py index 51b08e072..c2da23adc 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/subsampling.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/subsampling.py @@ -158,6 +158,12 @@ class VggSubsampling(nn.Module): self.out = nn.Linear( block_dims[-1] * (((idim - 1) // 2 - 1) // 2), odim ) + self.out_norm = BasicNorm(odim, learn_eps=False) + # constrain median of output to be close to zero. + self.out_balancer = ActivationBalancer(channel_dim=-1, + min_positive=0.45, + max_positive=0.55) + def forward(self, x: torch.Tensor) -> torch.Tensor: """Subsample x. @@ -173,4 +179,6 @@ class VggSubsampling(nn.Module): x = self.layers(x) b, c, t, f = x.size() x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f)) + x = self.out_norm(x) + x = self.out_balancer(x) return x