From a2dbce2a9acab425675b408bc95e05acc39e7810 Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Thu, 3 Nov 2022 13:02:54 +0800
Subject: [PATCH] Add Whiten module, with whitening_limit=10.0, at output of
 ModifiedSEModule

---
 .../ASR/pruned_transducer_stateless7/zipformer.py | 13 ++++++++++---
 1 file changed, 10 insertions(+), 3 deletions(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
index 1c181b1ac..0d4f44d14 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
@@ -1463,12 +1463,19 @@ class ModifiedSEModule(nn.Module):
             max_factor=0.01,
             min_prob=0.2,
         )
+        #self.bottleneck_norm = BasicNorm(bottleneck_dim)
+
         self.from_bottleneck_proj = ScaledLinear(bottleneck_dim, d_model)
-        self.sigmoid = nn.Sigmoid()  # make it a submodule for diagnostics purposes.
         self.out_proj = ScaledLinear(d_model, d_model, bias=False, initial_scale=0.1)
 
+        self.out_whiten = Whiten(num_groups=1,
+                                 whitening_limit=10.0,
+                                 prob=(0.025, 0.25),
+                                 grad_scale=0.01)
+
+
     def forward(self,
@@ -1497,16 +1504,16 @@ class ModifiedSEModule(nn.Module):
         squeezed = self.activation(squeezed)
         squeezed = self.to_bottleneck_proj(squeezed)
         squeezed = self.bottleneck_balancer(squeezed)
+        #squeezed = self.bottleneck_norm(squeezed)
         squeezed = self.from_bottleneck_proj(squeezed)
         if random.random() < 0.05:
             # to stop a hopefully-unlikely failure mode where the inputs to the sigmoid
             # get too large and the grads get mostly too small.
             squeezed = penalize_abs_values_gt(squeezed, limit=10.0, penalty=1.0e-04)
-        scales = self.sigmoid(squeezed)
 
         x = self.in_proj(x)
         x = x * squeezed
-        return self.out_proj(x)
+        return self.out_whiten(self.out_proj(x))
 
 
 class FeedforwardModule(nn.Module):
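
For readers outside this recipe: Whiten (defined in scaling.py alongside ScaledLinear, BasicNorm and ActivationBalancer) is an identity in the forward pass. With some probability during training it measures how far the covariance of the features is from a multiple of the identity, and injects an extra gradient that pushes that "whitening metric" back below whitening_limit; a metric of 1.0 means fully white features, so the 10.0 used here tolerates a fairly wide eigenvalue spread before penalizing. Below is a minimal sketch of the idea, not the scaling.py implementation: SimpleWhiten and whitening_metric are illustrative names, a fixed apply-probability stands in for the adaptive prob=(0.025, 0.25) range, and the real module controls the backward-pass magnitude of the penalty with a custom autograd function rather than the scalar trick used here.

import random

import torch
import torch.nn as nn


def whitening_metric(x: torch.Tensor, num_groups: int = 1) -> torch.Tensor:
    # Flatten all leading dims into "frames"; x: (..., num_channels).
    x = x.reshape(-1, x.shape[-1])
    num_channels = x.shape[-1]
    assert num_channels % num_groups == 0
    c = num_channels // num_groups
    x = x.reshape(-1, num_groups, c).transpose(0, 1)   # (groups, frames, c)
    # Uncentered covariance per group (no mean subtraction, for brevity).
    covar = torch.matmul(x.transpose(1, 2), x)         # (groups, c, c)
    # With eigenvalues l_i of covar, (covar ** 2).sum() == sum(l_i ** 2) and
    # the trace == sum(l_i), so the ratio below is 1.0 when all eigenvalues
    # are equal ("white" features) and grows with the eigenvalue spread.
    trace = covar.diagonal(dim1=1, dim2=2).sum(dim=1)  # (groups,)
    metric = c * (covar ** 2).sum(dim=(1, 2)) / (trace ** 2 + 1.0e-20)
    return metric.mean()


class SimpleWhiten(nn.Module):
    def __init__(self, num_groups: int = 1, whitening_limit: float = 10.0,
                 prob: float = 0.25, grad_scale: float = 0.01):
        super().__init__()
        self.num_groups = num_groups
        self.whitening_limit = whitening_limit
        self.prob = prob  # the real Whiten adapts this within a (min, max) range
        self.grad_scale = grad_scale

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if not self.training or random.random() > self.prob:
            return x
        metric = whitening_metric(x.float(), self.num_groups)
        if metric <= self.whitening_limit:
            return x
        penalty = self.grad_scale * (metric - self.whitening_limit)
        # (penalty - penalty.detach()) is exactly zero in value, so this
        # module never changes the forward output; it only injects the
        # gradient of `penalty` into the backward pass.
        return x + (penalty - penalty.detach()).to(x.dtype)

In the patch, self.out_whiten wraps the output of out_proj, so the constraint acts on the SE module's final output; because the forward value is untouched, the module is a no-op at inference time.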
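
The forward() code above also keeps the pre-existing penalize_abs_values_gt guard (also from scaling.py), applied with probability 0.05; per its comment, it counters a failure mode where the values of squeezed grow so large that gradients mostly vanish. (Note the patch deletes the unused scales = self.sigmoid(squeezed), so the comment's mention of a sigmoid is now historical; the penalty protects the values that directly scale x.) It follows the same identity-in-forward pattern, sketched here with a hypothetical helper (_AbsValuePenalty is an illustrative name, not the scaling.py implementation):

import torch


class _AbsValuePenalty(torch.autograd.Function):
    """Identity in the forward pass; the backward pass adds an extra
    gradient of magnitude `penalty` on elements with |x| > limit,
    nudging them back toward the interval [-limit, limit]."""

    @staticmethod
    def forward(ctx, x: torch.Tensor, limit: float, penalty: float) -> torch.Tensor:
        ctx.save_for_backward(x)
        ctx.limit = limit
        ctx.penalty = penalty
        return x

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        (x,) = ctx.saved_tensors
        extra = ctx.penalty * x.sign() * (x.abs() > ctx.limit)
        return grad_output + extra, None, None


# Used the same way as in the diff, e.g.:
#   squeezed = _AbsValuePenalty.apply(squeezed, 10.0, 1.0e-04)

Applying the penalty in the backward pass, rather than clamping the activations, keeps training-time and inference-time forward behavior identical.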