From 05b5e78d8f2298cf6b4b757a620df099dfc0841d Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 21 Mar 2022 15:55:11 +0800 Subject: [PATCH] Add norm+balancer to VggSubsampling --- .../ASR/pruned_transducer_stateless2/subsampling.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/subsampling.py b/egs/librispeech/ASR/pruned_transducer_stateless2/subsampling.py index 51b08e072..c2da23adc 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/subsampling.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/subsampling.py @@ -158,6 +158,12 @@ class VggSubsampling(nn.Module): self.out = nn.Linear( block_dims[-1] * (((idim - 1) // 2 - 1) // 2), odim ) + self.out_norm = BasicNorm(odim, learn_eps=False) + # constrain median of output to be close to zero. + self.out_balancer = ActivationBalancer(channel_dim=-1, + min_positive=0.45, + max_positive=0.55) + def forward(self, x: torch.Tensor) -> torch.Tensor: """Subsample x. @@ -173,4 +179,6 @@ class VggSubsampling(nn.Module): x = self.layers(x) b, c, t, f = x.size() x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f)) + x = self.out_norm(x) + x = self.out_balancer(x) return x