Make the ActivationBalancer regress to the data mean, not zero, when enforcing abs constraint.

Daniel Povey 2022-10-13 12:05:45 +08:00
parent b736bb4840
commit 9e30f2bf12


@@ -34,6 +34,7 @@ class ActivationBalancerFunction(torch.autograd.Function):
     def forward(
         ctx,
         x: Tensor,
+        mean: Tensor,
         sign_factor: Tensor,
         scale_factor: Tensor,
         channel_dim: int,
@@ -41,8 +42,13 @@ class ActivationBalancerFunction(torch.autograd.Function):
         if channel_dim < 0:
             channel_dim += x.ndim
         ctx.channel_dim = channel_dim
-        xgt0 = (x > 0)
-        ctx.save_for_backward(xgt0, sign_factor, scale_factor)
+        for _ in range(ctx.channel_dim, x.ndim - 1):
+            mean = mean.unsqueeze(-1)
+            sign_factor = sign_factor.unsqueeze(-1)
+            scale_factor = scale_factor.unsqueeze(-1)
+        xgtmean = (x > mean)
+        ctx.save_for_backward(xgtmean, sign_factor, scale_factor)
         return x
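The new unsqueeze loop reshapes mean, sign_factor and scale_factor (one value per channel) so they broadcast against x along channel_dim; doing this once in forward() also lets backward() drop its own copy of the loop. A small illustrative sketch of the broadcasting, with hypothetical shapes:

import torch

x = torch.randn(2, 3, 100, 5)    # say channel_dim = 1, so 3 channels
channel_dim = 1
mean = torch.zeros(3)            # one entry per channel

# Append trailing singleton dims until `mean` lines up with channel_dim:
# (3,) -> (3, 1) -> (3, 1, 1), which broadcasts against x's dims 1..3.
for _ in range(channel_dim, x.ndim - 1):
    mean = mean.unsqueeze(-1)

assert mean.shape == (3, 1, 1)
xgtmean = (x > mean)             # elementwise comparison, shape (2, 3, 100, 5)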
@@ -50,14 +56,11 @@ class ActivationBalancerFunction(torch.autograd.Function):
     def backward(
         ctx, x_grad: Tensor
     ) -> Tuple[Tensor, None, None, None]:
-        xgt0, sign_factor, scale_factor = ctx.saved_tensors
-        for _ in range(ctx.channel_dim, x_grad.ndim - 1):
-            sign_factor = sign_factor.unsqueeze(-1)
-            scale_factor = scale_factor.unsqueeze(-1)
+        xgtmean, sign_factor, scale_factor = ctx.saved_tensors
-        factor = sign_factor + scale_factor * (xgt0.to(x_grad.dtype) - 0.5)
+        factor = sign_factor + scale_factor * (xgtmean.to(x_grad.dtype) - 0.5)
         neg_delta_grad = x_grad.abs() * factor
-        return x_grad - neg_delta_grad, None, None, None,
+        return x_grad - neg_delta_grad, None, None, None, None,
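backward() gains one extra None because apply() now takes an extra non-differentiable argument (mean). The perturbation itself is unchanged in form: each element's gradient is nudged by a fraction of its own magnitude, but the comparison point is now the running mean rather than zero. A standalone sketch of what backward() computes, with hypothetical values and the sign constraints switched off:

import torch

x = torch.tensor([0.1, 0.3, -0.2])
mean = torch.tensor(0.05)                 # running per-channel mean
x_grad = torch.tensor([1.0, -2.0, 0.5])   # upstream gradient
sign_factor = torch.zeros(3)              # sign constraints off in this sketch
scale_factor = torch.full((3,), 0.04)     # abs value too small: push apart

xgtmean = (x > mean).to(x_grad.dtype)     # 1 above the mean, 0 below
factor = sign_factor + scale_factor * (xgtmean - 0.5)   # +-0.02 per element
neg_delta_grad = x_grad.abs() * factor
new_grad = x_grad - neg_delta_grad
# With positive scale_factor, elements above the mean get their gradient reduced
# by a fraction of its magnitude (so SGD moves them up) and elements below get
# it increased (so SGD moves them down): activations spread away from the
# running mean. With negative scale_factor the signs flip and values shrink
# toward the mean, which is the "regress to the data mean" in the commit title.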
@@ -275,6 +278,9 @@ class ActivationBalancer(torch.nn.Module):
         # count measures how many times the forward() function has been called.
         self.count = 0
+        # the mean of the data per channel
+        self.register_buffer('mean', torch.zeros(num_channels))
+        # the mean of the absolute value of the data per channel
+        self.register_buffer('abs_mean', torch.zeros(num_channels))
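The running statistics are registered as buffers, so they are saved in checkpoints and follow device/dtype moves without being trainable parameters. A quick sketch of those semantics:

import torch

m = torch.nn.Module()
m.register_buffer('mean', torch.zeros(4))
assert 'mean' in m.state_dict()         # saved and restored with the model
assert len(list(m.parameters())) == 0   # but not seen by the optimizer
m.to(torch.float64)                     # buffers follow .to() like parameters
assert m.mean.dtype == torch.float64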
@@ -307,7 +313,7 @@ class ActivationBalancer(torch.nn.Module):
             sign_factor = factors[0]
             scale_factor = factors[1]
             return ActivationBalancerFunction.apply(
-                x, sign_factor, scale_factor, self.channel_dim,
+                x, self.mean, sign_factor, scale_factor, self.channel_dim,
             )
         else:
             return x
@@ -322,6 +328,7 @@ class ActivationBalancer(torch.nn.Module):
         with torch.no_grad():
             sum_dims = [d for d in range(x.ndim) if d != self.channel_dim]
             x_mean = torch.mean(x, dim=sum_dims).to(torch.float32)
+            x_abs_mean = torch.mean(x.abs(), dim=sum_dims).to(torch.float32)
             # the random.random() thing is to split the difference if x is zero,
             # between treating it positive or negative
@@ -333,9 +340,11 @@ class ActivationBalancer(torch.nn.Module):
                 mask = (y - y != 0)
                 y.masked_fill_(mask, 0.0)
             filter_inf_nan(x_mean)
+            filter_inf_nan(x_abs_mean)
             beta = self.beta if count > 0 else 0.0
+            self.mean.mul_(beta).add_(x_mean, alpha=(1-beta))
+            self.abs_mean.mul_(beta).add_(x_abs_mean, alpha=(1-beta))
             self.proportion_positive.mul_(beta).add_(proportion_positive, alpha=(1-beta))
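Both new buffers are updated with the same exponential moving average as the existing proportion_positive statistic; forcing beta to 0 on the first call seeds the buffers directly from the first batch instead of mixing it with the zero initializer. A sketch of the update rule, with an assumed decay value:

import torch

beta = 0.95                    # assumed decay; the module's self.beta in reality
running_mean = torch.zeros(4)  # starts at zero, like the registered buffer

for step in range(10):
    batch_mean = torch.randn(4)        # stand-in for x_mean
    b = beta if step > 0 else 0.0      # first call: copy the batch stats verbatim
    # in place: running_mean = b * running_mean + (1 - b) * batch_mean
    running_mean.mul_(b).add_(batch_mean, alpha=(1 - b))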
@@ -363,25 +372,11 @@ class ActivationBalancer(torch.nn.Module):
             # the factor of 2.0 below is just to cancel out a factor of 0.5 that gets introduced when, in
             # the backprop, we do (xgt0.to(dtype) - 0.5).
-            #
-            # scale_factor_scale, on the other hand, is a heuristically chosen value between 0 and 1,
-            # that we use to make the gradient changes from the 'scale' constraints (min_abs/max_abs)
-            # less strong than those from the sign constraints.
-            #
-            # This is to get rid of a pathology that can happen if, for instance, a
-            # channel is always positive but is too small (max_positive and min_abs constraints both
-            # violated). If scale_factor_scale were equal to 1.0, then the gradient changes from the
-            # min_positive constraint (trying to make the activation more negative) and from the
-            # min_abs constraint (trying to make the activation more positive) would exactly cancel.
-            # Instead we make the min_positive constraint stronger, so it first makes the value
-            # sometimes negative, and only when that is satisfied, can deal with the absolute-value
-            # constraint.
-            scale_factor_scale = 0.5
             below_threshold = (self.abs_mean < self.min_abs)
             above_threshold = (self.abs_mean > self.max_abs)
             scale_factor[:] = ((below_threshold.to(torch.float32) -
                                 above_threshold.to(torch.float32))
-                               * (max_factor * (2.0 * scale_factor_scale)))
+                               * (max_factor * 2.0))


 class MaxEig(torch.nn.Module):
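With scale_factor_scale gone, the per-channel scale_factor is simply ±(2 * max_factor) when the smoothed abs_mean violates min_abs or max_abs, and 0 otherwise. The 2.0 cancels the 0.5 in (xgtmean - 0.5) during backward, and dropping the 0.5 scale_factor_scale doubles the effective strength of the abs constraints relative to the previous code. A small sketch with hypothetical thresholds:

import torch

abs_mean = torch.tensor([0.05, 0.5, 3.0])   # smoothed per-channel mean of |x|
min_abs, max_abs, max_factor = 0.2, 2.0, 0.02

below = (abs_mean < min_abs).to(torch.float32)  # too small: push away from mean
above = (abs_mean > max_abs).to(torch.float32)  # too large: pull toward mean
scale_factor = (below - above) * (max_factor * 2.0)
# -> tensor([ 0.0400,  0.0000, -0.0400])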