diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py index 042beb21e..2da9c6445 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py @@ -1384,7 +1384,9 @@ class FeedforwardModule(nn.Module): feedforward_dim: int, dropout: float): super(FeedforwardModule, self).__init__() - self.in_proj = nn.Linear(embed_dim, feedforward_dim) + self.in_proj = LinearWithAuxLoss(embed_dim, feedforward_dim, + aux_grad_scale=ScheduledFloat((0.0, 0.2), (1000.0, 0.01))) + self.hidden_balancer = ActivationBalancer(feedforward_dim, channel_dim=-1, max_abs=10.0, min_prob=0.25) @@ -1392,8 +1394,7 @@ class FeedforwardModule(nn.Module): self.dropout = nn.Dropout(dropout) self.out_proj = LinearWithAuxLoss(feedforward_dim, embed_dim, initial_scale=0.01, - aux_grad_scale=ScheduledFloat((0.0, 0.2), (1000.0, 0.01)), - ) + aux_grad_scale=ScheduledFloat((0.0, 0.2), (1000.0, 0.01))) self.out_whiten = Whiten(num_groups=1, whitening_limit=_whitening_schedule(7.5), prob=(0.025, 0.25),