Only apply ActivationBalancer with prob 0.25.

This commit is contained in:
Daniel Povey 2022-10-10 00:26:31 +08:00
parent dece8ad204
commit d7f6e8eb51

View File

@ -154,12 +154,9 @@ class MaxEigLimiterFunction(torch.autograd.Function):
x: Tensor, x: Tensor,
direction: Tensor, direction: Tensor,
channel_dim: int, channel_dim: int,
prob: float,
subtract_mean: bool, subtract_mean: bool,
max_variance_proportion: float, max_variance_proportion: float,
grad_scale: float) -> Tuple[Tensor, Tensor]: grad_scale: float) -> Tuple[Tensor, Tensor]:
if random.random() > prob:
return x, direction
eps = 1.0e-20 eps = 1.0e-20
num_channels = x.shape[channel_dim] num_channels = x.shape[channel_dim]
assert max_variance_proportion > 1.0 / num_channels assert max_variance_proportion > 1.0 / num_channels
@ -396,28 +393,31 @@ class ActivationBalancer(torch.nn.Module):
if torch.jit.is_scripting(): if torch.jit.is_scripting():
return x return x
if self.max_var_per_eig > 0:
max_eig_prob = 0.25 max_eig_prob = 0.25
if self.max_var_per_eig > 0 and random.random() < max_eig_prob:
with torch.cuda.amp.autocast(enabled=False): with torch.cuda.amp.autocast(enabled=False):
x, new_direction = MaxEigLimiterFunction.apply( x, new_direction = MaxEigLimiterFunction.apply(
x, self.max_eig_direction, x, self.max_eig_direction,
self.channel_dim, self.channel_dim,
max_eig_prob,
True, # subtract_mean True, # subtract_mean
self.max_var_per_eig, self.max_var_per_eig,
self.max_factor / max_eig_prob, self.max_factor / max_eig_prob,
) )
self.max_eig_direction[:] = new_direction.detach() self.max_eig_direction[:] = new_direction.detach()
balance_prob = 0.25
if random.random() < balance_prob:
return ActivationBalancerFunction.apply( return ActivationBalancerFunction.apply(
x, x,
self.channel_dim, self.channel_dim,
self.min_positive, self.min_positive,
self.max_positive, self.max_positive,
self.max_factor, self.max_factor / balance_prob,
self.min_abs, self.min_abs,
self.max_abs, self.max_abs,
) )
else:
return x
class DoubleSwishFunction(torch.autograd.Function): class DoubleSwishFunction(torch.autograd.Function):
@ -473,7 +473,6 @@ def _test_max_eig_limiter():
y, new_direction = MaxEigLimiterFunction.apply(x, direction, y, new_direction = MaxEigLimiterFunction.apply(x, direction,
1, # channel_dim 1, # channel_dim
1.0, # prob
True, # subtract_mean True, # subtract_mean
0.5, # max_variance_proportion 0.5, # max_variance_proportion
0.1, # grad_scale 0.1, # grad_scale