mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
Only apply ActivationBalancer with probability 0.25.
This commit is contained in:
parent
dece8ad204
commit
d7f6e8eb51
@ -154,12 +154,9 @@ class MaxEigLimiterFunction(torch.autograd.Function):
|
||||
x: Tensor,
|
||||
direction: Tensor,
|
||||
channel_dim: int,
|
||||
prob: float,
|
||||
subtract_mean: bool,
|
||||
max_variance_proportion: float,
|
||||
grad_scale: float) -> Tuple[Tensor, Tensor]:
|
||||
if random.random() > prob:
|
||||
return x, direction
|
||||
eps = 1.0e-20
|
||||
num_channels = x.shape[channel_dim]
|
||||
assert max_variance_proportion > 1.0 / num_channels
|
||||
@ -396,28 +393,31 @@ class ActivationBalancer(torch.nn.Module):
|
||||
if torch.jit.is_scripting():
|
||||
return x
|
||||
|
||||
if self.max_var_per_eig > 0:
|
||||
max_eig_prob = 0.25
|
||||
if self.max_var_per_eig > 0 and random.random() < max_eig_prob:
|
||||
with torch.cuda.amp.autocast(enabled=False):
|
||||
x, new_direction = MaxEigLimiterFunction.apply(
|
||||
x, self.max_eig_direction,
|
||||
self.channel_dim,
|
||||
max_eig_prob,
|
||||
True, # subtract_mean
|
||||
self.max_var_per_eig,
|
||||
self.max_factor / max_eig_prob,
|
||||
)
|
||||
self.max_eig_direction[:] = new_direction.detach()
|
||||
|
||||
balance_prob = 0.25
|
||||
if random.random() < balance_prob:
|
||||
return ActivationBalancerFunction.apply(
|
||||
x,
|
||||
self.channel_dim,
|
||||
self.min_positive,
|
||||
self.max_positive,
|
||||
self.max_factor,
|
||||
self.max_factor / balance_prob,
|
||||
self.min_abs,
|
||||
self.max_abs,
|
||||
)
|
||||
else:
|
||||
return x
|
||||
|
||||
|
||||
class DoubleSwishFunction(torch.autograd.Function):
|
||||
@ -473,7 +473,6 @@ def _test_max_eig_limiter():
|
||||
|
||||
y, new_direction = MaxEigLimiterFunction.apply(x, direction,
|
||||
1, # channel_dim
|
||||
1.0, # prob
|
||||
True, # subtract_mean
|
||||
0.5, # max_variance_proportion
|
||||
0.1, # grad_scale
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user