From d7f6e8eb514f80b7d779c4dd97f6b6e67012a62d Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Mon, 10 Oct 2022 00:26:31 +0800
Subject: [PATCH] Only apply ActivationBalancer with prob 0.25.

---
 .../pruned_transducer_stateless7/scaling.py | 31 +++++++++----------
 1 file changed, 15 insertions(+), 16 deletions(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
index 8432e4a47..24ddf892f 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
@@ -154,12 +154,9 @@ class MaxEigLimiterFunction(torch.autograd.Function):
         x: Tensor,
         direction: Tensor,
         channel_dim: int,
-        prob: float,
         subtract_mean: bool,
         max_variance_proportion: float,
         grad_scale: float) -> Tuple[Tensor, Tensor]:
-        if random.random() > prob:
-            return x, direction
         eps = 1.0e-20
         num_channels = x.shape[channel_dim]
         assert max_variance_proportion > 1.0 / num_channels
@@ -396,28 +393,31 @@ class ActivationBalancer(torch.nn.Module):
         if torch.jit.is_scripting():
             return x
 
-        if self.max_var_per_eig > 0:
-            max_eig_prob = 0.25
+        max_eig_prob = 0.25
+        if self.max_var_per_eig > 0 and random.random() < max_eig_prob:
             with torch.cuda.amp.autocast(enabled=False):
                 x, new_direction = MaxEigLimiterFunction.apply(
                     x, self.max_eig_direction, self.channel_dim,
-                    max_eig_prob,
                     True, # subtract_mean
                     self.max_var_per_eig,
                     self.max_factor / max_eig_prob,
                 )
                 self.max_eig_direction[:] = new_direction.detach()
 
-        return ActivationBalancerFunction.apply(
-            x,
-            self.channel_dim,
-            self.min_positive,
-            self.max_positive,
-            self.max_factor,
-            self.min_abs,
-            self.max_abs,
-        )
+        balance_prob = 0.25
+        if random.random() < balance_prob:
+            return ActivationBalancerFunction.apply(
+                x,
+                self.channel_dim,
+                self.min_positive,
+                self.max_positive,
+                self.max_factor / balance_prob,
+                self.min_abs,
+                self.max_abs,
+            )
+        else:
+            return x
 
 
 class DoubleSwishFunction(torch.autograd.Function):
@@ -473,7 +473,6 @@ def _test_max_eig_limiter():
     y, new_direction = MaxEigLimiterFunction.apply(x, direction,
                                                    1, # channel_dim
-                                                   1.0, # prob
                                                    True, # subtract_mean
                                                    0.5, # max_variance_proportion
                                                    0.1, # grad_scale
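
The pattern behind this change: the balancer and the eigenvalue limiter are now invoked on only about 25% of forward passes, and their gradient-modification strength (max_factor) is divided by that probability so the expected effect per training step stays roughly the same while most invocations are skipped. The code below is a minimal standalone sketch of that pattern, not part of the patch: StochasticBalancer and _ToyGradScaler are made-up names, and the sign-based gradient tweak is only a stand-in for the real balancing logic in scaling.py.

import random

import torch
from torch import Tensor


class _ToyGradScaler(torch.autograd.Function):
    # Identity in the forward pass; in the backward pass it adds
    # `scale * sign(x)` to the incoming gradient.  This is only a
    # stand-in for the real balancing math in scaling.py.
    @staticmethod
    def forward(ctx, x: Tensor, scale: float) -> Tensor:
        ctx.save_for_backward(x)
        ctx.scale = scale
        return x

    @staticmethod
    def backward(ctx, grad_output: Tensor):
        (x,) = ctx.saved_tensors
        return grad_output + ctx.scale * x.sign(), None


class StochasticBalancer(torch.nn.Module):
    # Hypothetical module: `max_factor` and `prob` play the roles of
    # `self.max_factor` and `balance_prob` in the patch above.
    def __init__(self, max_factor: float = 0.01, prob: float = 0.25):
        super().__init__()
        self.max_factor = max_factor
        self.prob = prob

    def forward(self, x: Tensor) -> Tensor:
        if random.random() < self.prob:
            # Dividing by `prob` compensates for the skipped calls, so the
            # expected gradient modification matches always-on behaviour.
            return _ToyGradScaler.apply(x, self.max_factor / self.prob)
        else:
            return x

Wrapping an activation with StochasticBalancer during training gives approximately the same expected regularization as applying it on every step, at roughly a quarter of the overhead, which is the trade-off this patch makes for ActivationBalancer and MaxEigLimiterFunction.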