From f625810de1221580e607554ce73cc71993fbf9ee Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Thu, 3 Nov 2022 19:21:37 +0800
Subject: [PATCH] Use the balancer; remove the unused sigmoid module.

---
 .../pruned_transducer_stateless7/zipformer.py | 28 ++-----------------
 1 file changed, 3 insertions(+), 25 deletions(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
index 9878939be..5f12694fb 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
@@ -1436,14 +1436,9 @@ class ModifiedSEModule(nn.Module):
                  d_model: int,
                  bottleneck_dim: int = 8):
         super().__init__()
-        self.squeeze_proj = nn.Linear(d_model, d_model,
+        self.squeeze_proj = nn.Linear(d_model, bottleneck_dim,
                                       bias=False)
 
-        # caution: this won't work well if the batch size is extremely small.
-        self.squeeze_whiten = Whiten(num_groups=1,
-                                     whitening_limit=10.0,
-                                     prob=(0.025, 0.25),
-                                     grad_scale=0.01)
 
         self.in_proj = nn.Linear(d_model, d_model,
                                  bias=False)
@@ -1456,23 +1451,14 @@ class ModifiedSEModule(nn.Module):
         self.balancer = ActivationBalancer(
             d_model, channel_dim=-1,
             min_positive=0.05, max_positive=0.95,
+            min_abs=0.1, max_abs=50.0,
-            max_factor=0.01,
+            max_factor=0.02,
             min_prob=0.2,
         )
         self.activation = DoubleSwish()
 
-        self.to_bottleneck_proj = ScaledLinear(d_model, bottleneck_dim)
-        self.bottleneck_balancer = ActivationBalancer(
-            bottleneck_dim, channel_dim=-1,
-            min_positive=0.05, max_positive=0.95,
-            max_abs=5.0,
-            min_abs=0.5,
-            max_factor=0.01,
-            min_prob=0.2,
-        )
         self.from_bottleneck_proj = ScaledLinear(bottleneck_dim, d_model)
-        self.sigmoid = nn.Sigmoid()  # make it a submodule for diagnostics purposes.
 
         self.out_proj = ScaledLinear(d_model, d_model,
                                      bias=False, initial_scale=0.1)
 
@@ -1501,17 +1487,9 @@ class ModifiedSEModule(nn.Module):
 
         squeezed = (x * pooling_mask).sum(dim=0, keepdim=True)
         squeezed = self.squeeze_proj(squeezed)
-        squeezed = self.squeeze_whiten(squeezed)
         squeezed = self.balancer(squeezed)
         squeezed = self.activation(squeezed)
-        squeezed = self.to_bottleneck_proj(squeezed)
-        squeezed = self.bottleneck_balancer(squeezed)
         squeezed = self.from_bottleneck_proj(squeezed)
-        if random.random() < 0.05:
-            # to stop a hopefully-unlikely failure mode where the inputs to the sigmoid
-            # get too large and the grads get mostly too small.
-            squeezed = penalize_abs_values_gt(squeezed, limit=10.0, penalty=1.0e-04)
-        scales = self.sigmoid(squeezed)
         x = self.in_proj(x)
         x = x * squeezed
 
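
Note, not part of the patch: for orientation, below is a minimal, self-contained sketch of the data flow the patched ModifiedSEModule ends up with, namely a squeeze projection straight to the bottleneck, an activation on the squeezed path, and the raw bottleneck output (no sigmoid) used as a per-channel scale. It is an approximation, not the icefall code: the class name SqueezeExciteSketch is hypothetical, plain nn.Linear stands in for ScaledLinear, nn.SiLU stands in for DoubleSwish, the ActivationBalancer regularizer is omitted, pooling_mask is assumed to have shape (seq_len, batch, 1) and to sum to 1 over the time axis, and the final out_proj return is assumed from code outside the shown hunks.

import torch
import torch.nn as nn


class SqueezeExciteSketch(nn.Module):
    # Rough stand-in for the patched ModifiedSEModule: plain nn.Linear replaces
    # ScaledLinear, nn.SiLU approximates DoubleSwish, and the ActivationBalancer
    # (a training-time regularizer) is omitted.
    def __init__(self, d_model: int, bottleneck_dim: int = 8):
        super().__init__()
        # After the patch, the squeeze projection maps straight to the bottleneck.
        self.squeeze_proj = nn.Linear(d_model, bottleneck_dim, bias=False)
        self.in_proj = nn.Linear(d_model, d_model, bias=False)
        self.activation = nn.SiLU()
        self.from_bottleneck_proj = nn.Linear(bottleneck_dim, d_model)
        self.out_proj = nn.Linear(d_model, d_model, bias=False)

    def forward(self, x: torch.Tensor, pooling_mask: torch.Tensor) -> torch.Tensor:
        # x: (seq_len, batch, d_model); pooling_mask: (seq_len, batch, 1),
        # assumed to sum to 1 over dim 0 so this is a weighted mean over time.
        squeezed = (x * pooling_mask).sum(dim=0, keepdim=True)
        squeezed = self.squeeze_proj(squeezed)
        squeezed = self.activation(squeezed)
        squeezed = self.from_bottleneck_proj(squeezed)
        # No sigmoid any more: the bottleneck output itself scales the channels.
        x = self.in_proj(x)
        x = x * squeezed
        return self.out_proj(x)


if __name__ == "__main__":
    seq_len, batch, d_model = 100, 4, 256
    m = SqueezeExciteSketch(d_model)
    x = torch.randn(seq_len, batch, d_model)
    pooling_mask = torch.full((seq_len, batch, 1), 1.0 / seq_len)
    print(m(x, pooling_mask).shape)  # torch.Size([100, 4, 256])

The usage stub at the bottom only checks shapes; in the real module the balancer and the scaled-initialization linears matter for training stability, which is what the patch is adjusting.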