Only apply ActivationBalancer with prob 0.25.

Daniel Povey 2022-10-10 00:26:31 +08:00
parent dece8ad204
commit d7f6e8eb51
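
The pattern this commit introduces: rather than running the balancing correction on every training step, gate each apply() call on random.random() < prob and divide the gradient scale by prob, so the correction fires on only about a quarter of the steps but its expected magnitude per step is unchanged. A minimal sketch of that pattern follows; ScaledGradFunction and maybe_balance are illustrative stand-ins, not names from the repository, and the toy penalty gradient is an assumption for demonstration only.

    import random
    import torch
    from torch import Tensor


    class ScaledGradFunction(torch.autograd.Function):
        # Identity in the forward pass; the backward pass adds a toy penalty
        # gradient, scaled by grad_scale, that nudges activations toward zero.
        @staticmethod
        def forward(ctx, x: Tensor, grad_scale: float) -> Tensor:
            ctx.save_for_backward(x)
            ctx.grad_scale = grad_scale
            return x

        @staticmethod
        def backward(ctx, grad_output: Tensor):
            (x,) = ctx.saved_tensors
            return grad_output + ctx.grad_scale * x.sign(), None


    def maybe_balance(x: Tensor, grad_scale: float, prob: float = 0.25) -> Tensor:
        # Fire the correction only `prob` of the time, dividing grad_scale by
        # prob so the expected correction per training step is unchanged.
        if random.random() < prob:
            return ScaledGradFunction.apply(x, grad_scale / prob)
        else:
            return x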


@@ -154,12 +154,9 @@ class MaxEigLimiterFunction(torch.autograd.Function):
                 x: Tensor,
                 direction: Tensor,
                 channel_dim: int,
-                prob: float,
                 subtract_mean: bool,
                 max_variance_proportion: float,
                 grad_scale: float) -> Tuple[Tensor, Tensor]:
-        if random.random() > prob:
-            return x, direction
         eps = 1.0e-20
         num_channels = x.shape[channel_dim]
         assert max_variance_proportion > 1.0 / num_channels
@@ -396,28 +393,31 @@ class ActivationBalancer(torch.nn.Module):
         if torch.jit.is_scripting():
             return x
-        if self.max_var_per_eig > 0:
-            max_eig_prob = 0.25
+        max_eig_prob = 0.25
+        if self.max_var_per_eig > 0 and random.random() < max_eig_prob:
             with torch.cuda.amp.autocast(enabled=False):
                 x, new_direction = MaxEigLimiterFunction.apply(
                     x, self.max_eig_direction,
                     self.channel_dim,
-                    max_eig_prob,
                     True, # subtract_mean
                     self.max_var_per_eig,
                     self.max_factor / max_eig_prob,
                 )
                 self.max_eig_direction[:] = new_direction.detach()
-        return ActivationBalancerFunction.apply(
-            x,
-            self.channel_dim,
-            self.min_positive,
-            self.max_positive,
-            self.max_factor,
-            self.min_abs,
-            self.max_abs,
-        )
+        balance_prob = 0.25
+        if random.random() < balance_prob:
+            return ActivationBalancerFunction.apply(
+                x,
+                self.channel_dim,
+                self.min_positive,
+                self.max_positive,
+                self.max_factor / balance_prob,
+                self.min_abs,
+                self.max_abs,
+            )
+        else:
+            return x


 class DoubleSwishFunction(torch.autograd.Function):
@@ -473,7 +473,6 @@ def _test_max_eig_limiter():
     y, new_direction = MaxEigLimiterFunction.apply(x, direction,
                                                    1, # channel_dim
-                                                   1.0, # prob
                                                    True, # subtract_mean
                                                    0.5, # max_variance_proportion
                                                    0.1, # grad_scale
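
As a quick sanity check of the compensation (again a hypothetical demo built on the sketch above, not code from the repository), accumulating gradients over many steps shows the stochastic version matching the always-on version in expectation:

    torch.manual_seed(0)
    x_always = torch.randn(100, requires_grad=True)
    x_random = x_always.detach().clone().requires_grad_(True)

    for _ in range(2000):
        ScaledGradFunction.apply(x_always, 0.1).sum().backward()
        maybe_balance(x_random, grad_scale=0.1, prob=0.25).sum().backward()

    # Both .grad buffers accumulate roughly 2000 * (1 + 0.1 * x.sign());
    # the stochastic version matches in expectation because grad_scale is
    # divided by prob on the steps where it fires.
    print((x_random.grad - x_always.grad).abs().mean())  # small relative to ~2000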