Only apply ActivationBalancer with prob 0.25.

commit d7f6e8eb51
parent dece8ad204
@@ -154,12 +154,9 @@ class MaxEigLimiterFunction(torch.autograd.Function):
                 x: Tensor,
                 direction: Tensor,
                 channel_dim: int,
-                prob: float,
                 subtract_mean: bool,
                 max_variance_proportion: float,
                 grad_scale: float) -> Tuple[Tensor, Tensor]:
-        if random.random() > prob:
-            return x, direction
         eps = 1.0e-20
         num_channels = x.shape[channel_dim]
         assert max_variance_proportion > 1.0 / num_channels
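With `prob` and the early return removed, MaxEigLimiterFunction is deterministic and the coin flip moves to the caller (next hunk). A practical benefit of gating at the call site is that a skipped step never invokes the Function, so no autograd node is recorded at all. A minimal sketch of that pattern, under illustrative names (`ScaledGradPenalty` and `maybe_penalize` are not icefall code):

import random

import torch
from torch import Tensor


class ScaledGradPenalty(torch.autograd.Function):
    """Identity in forward; backward adds a gradient term that nudges
    activations toward zero. Deterministic: randomness is the caller's job,
    mirroring how this commit strips `prob` out of MaxEigLimiterFunction."""

    @staticmethod
    def forward(ctx, x: Tensor, grad_scale: float) -> Tensor:
        ctx.save_for_backward(x)
        ctx.grad_scale = grad_scale
        return x

    @staticmethod
    def backward(ctx, x_grad: Tensor):
        (x,) = ctx.saved_tensors
        # Add a term with the sign of x; a gradient-descent step on this
        # gradient shrinks |x|, with grad_scale controlling the strength.
        return x_grad + ctx.grad_scale * torch.sign(x), None


def maybe_penalize(x: Tensor, grad_scale: float, prob: float = 0.25) -> Tensor:
    # Gate at the call site: when the coin flip fails, the Function is never
    # applied and the forward pass is a plain no-op.
    if random.random() < prob:
        # Divide by prob so the *expected* gradient term is unchanged.
        return ScaledGradPenalty.apply(x, grad_scale / prob)
    return x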
@@ -396,28 +393,31 @@ class ActivationBalancer(torch.nn.Module):
         if torch.jit.is_scripting():
             return x
 
-        if self.max_var_per_eig > 0:
-            max_eig_prob = 0.25
+        max_eig_prob = 0.25
+        if self.max_var_per_eig > 0 and random.random() < max_eig_prob:
             with torch.cuda.amp.autocast(enabled=False):
                 x, new_direction = MaxEigLimiterFunction.apply(
                     x, self.max_eig_direction,
                     self.channel_dim,
-                    max_eig_prob,
                     True,  # subtract_mean
                     self.max_var_per_eig,
                     self.max_factor / max_eig_prob,
                 )
                 self.max_eig_direction[:] = new_direction.detach()
 
-        return ActivationBalancerFunction.apply(
-            x,
-            self.channel_dim,
-            self.min_positive,
-            self.max_positive,
-            self.max_factor,
-            self.min_abs,
-            self.max_abs,
-        )
+        balance_prob = 0.25
+        if random.random() < balance_prob:
+            return ActivationBalancerFunction.apply(
+                x,
+                self.channel_dim,
+                self.min_positive,
+                self.max_positive,
+                self.max_factor / balance_prob,
+                self.min_abs,
+                self.max_abs,
+            )
+        else:
+            return x
 
 
 class DoubleSwishFunction(torch.autograd.Function):
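Both gates compensate by dividing the correction strength by the firing probability (`self.max_factor / max_eig_prob`, `self.max_factor / balance_prob`): a term applied with probability p and scaled by 1/p has the same expectation as one applied every step, so the average regularization pressure is preserved while roughly three steps in four skip the work. A quick self-contained check of that identity (`expected_extra_grad` is a hypothetical helper, not icefall code):

import random

def expected_extra_grad(base_factor: float, prob: float,
                        trials: int = 200_000) -> float:
    """Monte Carlo estimate of the mean per-step gradient scale when the
    correction fires with probability `prob` and is scaled by 1/prob."""
    total = 0.0
    for _ in range(trials):
        if random.random() < prob:
            total += base_factor / prob
    return total / trials

# Both settings have (approximately) the same expectation:
print(expected_extra_grad(0.02, prob=1.0))   # ~0.02, applied every step
print(expected_extra_grad(0.02, prob=0.25))  # ~0.02, applied 1 step in 4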
@@ -473,7 +473,6 @@ def _test_max_eig_limiter():
 
     y, new_direction = MaxEigLimiterFunction.apply(x, direction,
                                                    1,    # channel_dim
-                                                   1.0,  # prob
                                                    True, # subtract_mean
                                                    0.5,  # max_variance_proportion
                                                    0.1,  # grad_scale