diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py index 3fe71698b..96906b726 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py @@ -397,14 +397,15 @@ class ActivationBalancer(torch.nn.Module): return x if self.max_var_per_eig > 0: + max_eig_prob = 0.25 with torch.cuda.amp.autocast(enabled=False): x, new_direction = MaxEigLimiterFunction.apply( x, self.max_eig_direction, self.channel_dim, - 0.25, # prob + max_eig_prob, True, # subtract_mean self.max_var_per_eig, - self.max_factor, + self.max_factor / max_eig_prob, ) self.max_eig_direction[:] = new_direction.detach()