Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-12-11 06:55:27 +00:00)
Make the ActivationBalancer regress to the data mean, not zero, when enforcing abs constraint.

Commit 9e30f2bf12 (parent b736bb4840)
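The diff below modifies ActivationBalancerFunction and ActivationBalancer. As a minimal, self-contained sketch of the idea (hypothetical helper name and values, not code from the commit): the backward pass now decides which way to nudge each gradient element by comparing the activation against a running per-channel mean instead of against zero.

import torch

def balancer_grad_adjustment(x, x_grad, mean, sign_factor, scale_factor):
    # Before this commit the comparison was (x > 0); after it, (x > mean),
    # so the abs constraint regresses activations toward the data mean.
    xgtmean = (x > mean)
    factor = sign_factor + scale_factor * (xgtmean.to(x_grad.dtype) - 0.5)
    neg_delta_grad = x_grad.abs() * factor
    return x_grad - neg_delta_grad

# Illustrative call: (batch, channels) activations with per-channel statistics.
x = torch.randn(8, 4)
adjusted = balancer_grad_adjustment(
    x, torch.ones_like(x), torch.zeros(4), torch.zeros(4), torch.full((4,), 0.02))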
@@ -34,6 +34,7 @@ class ActivationBalancerFunction(torch.autograd.Function):
     def forward(
         ctx,
         x: Tensor,
+        mean: Tensor,
         sign_factor: Tensor,
         scale_factor: Tensor,
         channel_dim: int,
@@ -41,8 +42,13 @@ class ActivationBalancerFunction(torch.autograd.Function):
         if channel_dim < 0:
             channel_dim += x.ndim
         ctx.channel_dim = channel_dim
-        xgt0 = (x > 0)
-        ctx.save_for_backward(xgt0, sign_factor, scale_factor)
+        for _ in range(ctx.channel_dim, x.ndim - 1):
+            mean = mean.unsqueeze(-1)
+            sign_factor = sign_factor.unsqueeze(-1)
+            scale_factor = scale_factor.unsqueeze(-1)
+
+        xgtmean = (x > mean)
+        ctx.save_for_backward(xgtmean, sign_factor, scale_factor)
         return x
@@ -50,14 +56,11 @@ class ActivationBalancerFunction(torch.autograd.Function):
     def backward(
         ctx, x_grad: Tensor
     ) -> Tuple[Tensor, None, None, None]:
-        xgt0, sign_factor, scale_factor = ctx.saved_tensors
-        for _ in range(ctx.channel_dim, x_grad.ndim - 1):
-            sign_factor = sign_factor.unsqueeze(-1)
-            scale_factor = scale_factor.unsqueeze(-1)
+        xgtmean, sign_factor, scale_factor = ctx.saved_tensors

-        factor = sign_factor + scale_factor * (xgt0.to(x_grad.dtype) - 0.5)
+        factor = sign_factor + scale_factor * (xgtmean.to(x_grad.dtype) - 0.5)
         neg_delta_grad = x_grad.abs() * factor
-        return x_grad - neg_delta_grad, None, None, None,
+        return x_grad - neg_delta_grad, None, None, None, None,
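A hedged usage sketch of the updated autograd function as defined in the two hunks above; it assumes ActivationBalancerFunction from the modified file is in scope, and the shapes and factor values are illustrative only.

import torch

num_channels = 4
x = torch.randn(8, num_channels, requires_grad=True)
mean = torch.zeros(num_channels)                  # running per-channel mean (new argument)
sign_factor = torch.zeros(num_channels)           # per-channel sign-constraint strength
scale_factor = torch.full((num_channels,), 0.01)  # per-channel abs-constraint strength

# forward() returns x unchanged; backward() nudges x.grad per channel and now
# returns one extra None to account for the new 'mean' input.
y = ActivationBalancerFunction.apply(x, mean, sign_factor, scale_factor, -1)
y.sum().backward()
print(x.grad.shape)  # torch.Size([8, 4])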
@@ -275,6 +278,9 @@ class ActivationBalancer(torch.nn.Module):
         # count measures how many times the forward() function has been called.
         self.count = 0

+        # the mean of the data per channel
+        self.register_buffer('mean', torch.zeros(num_channels))
+
         # the mean of the absolute value of the data per channel
         self.register_buffer('abs_mean', torch.zeros(num_channels))
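For context, a small sketch (hypothetical Demo module, not from the repository) of what register_buffer gives these running statistics: they are saved and restored with the state_dict and follow device moves, but they are not parameters, so the optimizer never touches them.

import torch

class Demo(torch.nn.Module):
    def __init__(self, num_channels: int):
        super().__init__()
        # same pattern as self.register_buffer('mean', torch.zeros(num_channels)) above
        self.register_buffer('mean', torch.zeros(num_channels))

m = Demo(4)
print('mean' in m.state_dict())  # True: checkpointed with the module
print(list(m.parameters()))      # []: buffers are not trainable parameters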
@@ -307,7 +313,7 @@ class ActivationBalancer(torch.nn.Module):
             sign_factor = factors[0]
             scale_factor = factors[1]
             return ActivationBalancerFunction.apply(
-                x, sign_factor, scale_factor, self.channel_dim,
+                x, self.mean, sign_factor, scale_factor, self.channel_dim,
             )
         else:
             return x
@@ -322,6 +328,7 @@ class ActivationBalancer(torch.nn.Module):
         with torch.no_grad():
             sum_dims = [d for d in range(x.ndim) if d != self.channel_dim]

+            x_mean = torch.mean(x, dim=sum_dims).to(torch.float32)
             x_abs_mean = torch.mean(x.abs(), dim=sum_dims).to(torch.float32)
             # the random.random() thing is to split the difference if x is zero,
             # between treating it positive or negative
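A short sketch of the per-channel statistics this hunk computes, assuming a (batch, channels, time) activation with channel_dim = 1; the shapes are noted in comments and the tensor values are illustrative.

import torch

x = torch.randn(2, 4, 50)  # (batch, channels, time), illustrative
channel_dim = 1
sum_dims = [d for d in range(x.ndim) if d != channel_dim]

x_mean = torch.mean(x, dim=sum_dims).to(torch.float32)            # shape (4,): new per-channel mean
x_abs_mean = torch.mean(x.abs(), dim=sum_dims).to(torch.float32)  # shape (4,): existing per-channel abs mean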
@@ -333,9 +340,11 @@ class ActivationBalancer(torch.nn.Module):
                 mask = (y - y != 0)
                 y.masked_fill_(mask, 0.0)

+            filter_inf_nan(x_mean)
             filter_inf_nan(x_abs_mean)

             beta = self.beta if count > 0 else 0.0
+            self.mean.mul_(beta).add_(x_mean, alpha=(1-beta))
             self.abs_mean.mul_(beta).add_(x_abs_mean, alpha=(1-beta))
             self.proportion_positive.mul_(beta).add_(proportion_positive, alpha=(1-beta))
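A sketch of the in-place exponential moving average used to track the new mean buffer, with an illustrative beta; note that on the first call (count == 0) beta is 0.0, so the buffer is simply overwritten by the first batch statistic.

import torch

beta = 0.95              # illustrative smoothing constant
mean = torch.zeros(4)    # plays the role of the self.mean buffer
x_mean = torch.randn(4)  # per-channel mean from the current forward() call

# mean := beta * mean + (1 - beta) * x_mean, done in place as in the hunk above
mean.mul_(beta).add_(x_mean, alpha=(1 - beta))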
@@ -363,25 +372,11 @@ class ActivationBalancer(torch.nn.Module):

             # the factor of 2.0 below is just to cancel out a factor of 0.5 that gets introduced when, in
             # the backprop, we do (xgt0.to(dtype) - 0.5).
-            #
-            # scale_factor_scale, on the other hand, is a heuristically chosen value between 0 and 1,
-            # that we use to make the gradient changes from the 'scale' constraints (min_abs/max_abs)
-            # less strong than those from the sign constraints.
-            #
-            # This is to get rid of a pathology that can happen if, for instance, a
-            # channel is always positive but is too small (max_positive and min_abs constraints both
-            # violated). If scale_factor_scale were equal to 1.0, then the gradient changes from the
-            # min_positive constraint (trying to make the activation more negative) and from the
-            # min_abs constraint (trying to make the activation more positive) would exactly cancel.
-            # Instead we make the min_positive constraint stronger, so it first makes the value
-            # sometimes negative, and only when that is satisfied, can deal with the absolute-value
-            # constraint.
-            scale_factor_scale = 0.5
             below_threshold = (self.abs_mean < self.min_abs)
             above_threshold = (self.abs_mean > self.max_abs)
             scale_factor[:] = ((below_threshold.to(torch.float32) -
                                 above_threshold.to(torch.float32))
-                                * (max_factor * (2.0 * scale_factor_scale)))
+                                * (max_factor * 2.0))


 class MaxEig(torch.nn.Module):
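A quick numeric check of the "factor of 2.0" comment retained above, with an illustrative max_factor: (xgtmean.to(dtype) - 0.5) is ±0.5 in the backward pass, so multiplying by 2.0 restores a scale-term magnitude of max_factor for any channel that violates min_abs or max_abs (the separate 0.5 down-weighting via scale_factor_scale is what this hunk removes).

max_factor = 0.01                                # illustrative value
scale_factor_value = max_factor * 2.0            # what this hunk assigns (up to sign)
scale_term_magnitude = scale_factor_value * 0.5  # the 0.5 from (xgtmean - 0.5)
assert abs(scale_term_magnitude - max_factor) < 1e-12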