diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py index 6e2a4deb4..35ae7624c 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py @@ -1260,7 +1260,8 @@ class FeedforwardModule(nn.Module): super(FeedforwardModule, self).__init__() self.in_proj = nn.Linear(d_model, feedforward_dim) self.balancer = ActivationBalancer(feedforward_dim, - channel_dim=-1, max_abs=10.0) + channel_dim=-1, max_abs=10.0, + min_prob=0.25) self.activation = DoubleSwish() self.dropout = nn.Dropout(dropout) self.out_proj = ScaledLinear(feedforward_dim, d_model, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py index 5b4fece3d..676110675 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py @@ -520,7 +520,7 @@ class ActivationBalancer(torch.nn.Module): # the prob of doing some work exponentially decreases from 0.5 till it hits # a floor at min_prob (==0.1, by default) - prob = max(self.min_prob, 0.5 ** (1 + (count/2000.0))) + prob = max(self.min_prob, 0.5 ** (1 + (count/4000.0))) if random.random() < prob: sign_gain_factor = 0.5