FeedforwardModule: rename the hidden-layer balancer and add an output balancer.

* `self.balancer` becomes `self.hidden_balancer` (constructor arguments unchanged).
* A new `self.out_balancer` (ActivationBalancer over `embed_dim`) is created in
  `__init__` and applied to the result of `self.out_proj` in `forward`.

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
index 233b38da8..ce3b08ae8 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
@@ -1381,21 +1381,27 @@ class FeedforwardModule(nn.Module):
                  dropout: float):
         super(FeedforwardModule, self).__init__()
         self.in_proj = nn.Linear(embed_dim, feedforward_dim)
-        self.balancer = ActivationBalancer(feedforward_dim,
-                                           channel_dim=-1, max_abs=10.0,
-                                           min_prob=0.25)
+        self.hidden_balancer = ActivationBalancer(feedforward_dim,
+                                                  channel_dim=-1, max_abs=10.0,
+                                                  min_prob=0.25)
         self.activation = DoubleSwish()
         self.dropout = nn.Dropout(dropout)
         self.out_proj = ScaledLinear(feedforward_dim, embed_dim,
                                      initial_scale=0.01)
 
+        self.out_balancer = ActivationBalancer(embed_dim,
+                                               min_positive=0.4, max_positive=0.6,
+                                               min_abs=0.01, max_abs=5.0,
+                                               channel_dim=-1, min_prob=0.1)
+
     def forward(self,
                 x: Tensor):
         x = self.in_proj(x)
-        x = self.balancer(x)
+        x = self.hidden_balancer(x)
         x = self.activation(x)
         x = self.dropout(x)
         x = self.out_proj(x)
+        x = self.out_balancer(x)
         return x
 
 