diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py index 1c181b1ac..0d4f44d14 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py @@ -1463,12 +1463,19 @@ class ModifiedSEModule(nn.Module): max_factor=0.01, min_prob=0.2, ) + #self.bottleneck_norm = BasicNorm(bottleneck_dim) + self.from_bottleneck_proj = ScaledLinear(bottleneck_dim, d_model) - self.sigmoid = nn.Sigmoid() # make it a submodule for diagnostics purposes. self.out_proj = ScaledLinear(d_model, d_model, bias=False, initial_scale=0.1) + self.out_whiten = Whiten(num_groups=1, + whitening_limit=10.0, + prob=(0.025, 0.25), + grad_scale=0.01) + + def forward(self, @@ -1497,16 +1504,16 @@ class ModifiedSEModule(nn.Module): squeezed = self.activation(squeezed) squeezed = self.to_bottleneck_proj(squeezed) squeezed = self.bottleneck_balancer(squeezed) + #squeezed = self.bottleneck_norm(squeezed) squeezed = self.from_bottleneck_proj(squeezed) if random.random() < 0.05: # to stop a hopefully-unlikely failure mode where the inputs to the sigmoid # get too large and the grads get mostly too small. squeezed = penalize_abs_values_gt(squeezed, limit=10.0, penalty=1.0e-04) - scales = self.sigmoid(squeezed) x = self.in_proj(x) x = x * squeezed - return self.out_proj(x) + return self.out_whiten(self.out_proj(x)) class FeedforwardModule(nn.Module):