Make the ActivationBalancer regress to the data mean, not zero, when enforcing abs constraint.
Commit 9e30f2bf12 (parent b736bb4840)
@@ -34,6 +34,7 @@ class ActivationBalancerFunction(torch.autograd.Function):
     def forward(
         ctx,
         x: Tensor,
+        mean: Tensor,
         sign_factor: Tensor,
         scale_factor: Tensor,
         channel_dim: int,
@@ -41,8 +42,13 @@ class ActivationBalancerFunction(torch.autograd.Function):
         if channel_dim < 0:
             channel_dim += x.ndim
         ctx.channel_dim = channel_dim
-        xgt0 = (x > 0)
-        ctx.save_for_backward(xgt0, sign_factor, scale_factor)
+        for _ in range(ctx.channel_dim, x.ndim - 1):
+            mean = mean.unsqueeze(-1)
+            sign_factor = sign_factor.unsqueeze(-1)
+            scale_factor = scale_factor.unsqueeze(-1)
+
+        xgtmean = (x > mean)
+        ctx.save_for_backward(xgtmean, sign_factor, scale_factor)
         return x
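Editor's note (not part of the patch): the new forward() relies on broadcasting the per-channel statistics against x. A minimal sketch of that unsqueeze pattern, with an illustrative shape and channel_dim:

import torch

x = torch.randn(4, 8, 10)        # hypothetical (batch, channels, time), channel_dim = 1
channel_dim = 1
mean = torch.zeros(8)            # per-channel running mean, shape (num_channels,)

# add trailing singleton dims so `mean` lines up with the channel axis of `x`
for _ in range(channel_dim, x.ndim - 1):
    mean = mean.unsqueeze(-1)    # -> shape (8, 1)

xgtmean = (x > mean)             # broadcasts to (4, 8, 10)
print(xgtmean.shape)             # torch.Size([4, 8, 10])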
@@ -50,14 +56,11 @@ class ActivationBalancerFunction(torch.autograd.Function):
     def backward(
         ctx, x_grad: Tensor
     ) -> Tuple[Tensor, None, None, None]:
-        xgt0, sign_factor, scale_factor = ctx.saved_tensors
-        for _ in range(ctx.channel_dim, x_grad.ndim - 1):
-            sign_factor = sign_factor.unsqueeze(-1)
-            scale_factor = scale_factor.unsqueeze(-1)
-
-        factor = sign_factor + scale_factor * (xgt0.to(x_grad.dtype) - 0.5)
+        xgtmean, sign_factor, scale_factor = ctx.saved_tensors
+        factor = sign_factor + scale_factor * (xgtmean.to(x_grad.dtype) - 0.5)
         neg_delta_grad = x_grad.abs() * factor
-        return x_grad - neg_delta_grad, None, None, None,
+        return x_grad - neg_delta_grad, None, None, None, None,
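Editor's note (not part of the patch): a small sketch, with made-up numbers, of how the extra gradient term now changes sign at the running mean rather than at zero. With a positive scale_factor (the channel's abs_mean is below min_abs), elements above the mean get a slightly smaller gradient and elements below it a slightly larger one:

import torch

x = torch.tensor([-1.0, 0.2, 0.5, 2.0])
mean = torch.tensor(0.5)            # per-channel running mean (a scalar here)
x_grad = torch.ones_like(x)         # pretend upstream gradient

sign_factor = torch.tensor(0.0)     # ignore the sign constraint in this sketch
scale_factor = torch.tensor(0.04)   # > 0: the channel's abs_mean is below min_abs

xgtmean = (x > mean)
factor = sign_factor + scale_factor * (xgtmean.to(x_grad.dtype) - 0.5)
neg_delta_grad = x_grad.abs() * factor
print(x_grad - neg_delta_grad)      # tensor([1.0200, 1.0200, 1.0200, 0.9800])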
@@ -275,6 +278,9 @@ class ActivationBalancer(torch.nn.Module):
         # count measures how many times the forward() function has been called.
         self.count = 0
 
+        # the mean of the data per channel
+        self.register_buffer('mean', torch.zeros(num_channels))
+
         # the mean of the absolute value of the data per channel
         self.register_buffer('abs_mean', torch.zeros(num_channels))
 
@@ -307,7 +313,7 @@ class ActivationBalancer(torch.nn.Module):
             sign_factor = factors[0]
             scale_factor = factors[1]
             return ActivationBalancerFunction.apply(
-                x, sign_factor, scale_factor, self.channel_dim,
+                x, self.mean, sign_factor, scale_factor, self.channel_dim,
             )
         else:
             return x
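Editor's note (not part of the patch): because apply() now receives one more input, backward() has to return one more gradient slot, which is why the backward hunk above appends an extra None. A minimal, unrelated illustration of that autograd rule:

import torch
from torch import Tensor

class PassThrough(torch.autograd.Function):
    # toy Function: forward takes three inputs, so backward returns three values
    @staticmethod
    def forward(ctx, x: Tensor, mean: Tensor, channel_dim: int):
        return x

    @staticmethod
    def backward(ctx, x_grad: Tensor):
        return x_grad, None, None   # one entry per forward() input

x = torch.randn(3, requires_grad=True)
y = PassThrough.apply(x, torch.zeros(3), -1)
y.sum().backward()
print(x.grad)                       # tensor([1., 1., 1.])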
@@ -322,6 +328,7 @@ class ActivationBalancer(torch.nn.Module):
         with torch.no_grad():
             sum_dims = [d for d in range(x.ndim) if d != self.channel_dim]
 
+            x_mean = torch.mean(x, dim=sum_dims).to(torch.float32)
             x_abs_mean = torch.mean(x.abs(), dim=sum_dims).to(torch.float32)
             # the random.random() thing is to split the difference if x is zero,
             # between treating it positive or negative
@@ -333,9 +340,11 @@ class ActivationBalancer(torch.nn.Module):
                 mask = (y - y != 0)
                 y.masked_fill_(mask, 0.0)
 
+            filter_inf_nan(x_mean)
             filter_inf_nan(x_abs_mean)
 
             beta = self.beta if count > 0 else 0.0
+            self.mean.mul_(beta).add_(x_mean, alpha=(1-beta))
             self.abs_mean.mul_(beta).add_(x_abs_mean, alpha=(1-beta))
             self.proportion_positive.mul_(beta).add_(proportion_positive, alpha=(1-beta))
 
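Editor's note (not part of the patch): the new buffer is updated with the same exponential-moving-average pattern as abs_mean. A tiny numeric sketch; the beta value and batch statistic are made up, and on the first call beta is forced to 0.0 so the buffer is seeded from the first batch:

import torch

beta = 0.95                                     # illustrative decay value
mean = torch.zeros(3)                           # stands in for the self.mean buffer
x_mean = torch.tensor([0.1, -0.3, 0.7])         # per-channel batch statistic

mean.mul_(beta).add_(x_mean, alpha=(1 - beta))  # mean = beta*mean + (1-beta)*x_mean
print(mean)                                     # tensor([ 0.0050, -0.0150,  0.0350])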
@@ -363,25 +372,11 @@ class ActivationBalancer(torch.nn.Module):
 
             # the factor of 2.0 below is just to cancel out a factor of 0.5 that gets introduced when, in
             # the backprop, we do (xgt0.to(dtype) - 0.5).
-            #
-            # scale_factor_scale, on the other hand, is a heuristically chosen value between 0 and 1,
-            # that we use to make the gradient changes from the 'scale' constraints (min_abs/max_abs)
-            # less strong than those from the sign constraints.
-            #
-            # This is to get rid of a pathology that can happen if, for instance, a
-            # channel is always positive but is too small (max_positive and min_abs constraints both
-            # violated).  If scale_factor_scale were equal to 1.0, then the gradient changes from the
-            # min_positive constraint (trying to make the activation more negative) and from the
-            # min_abs constraint (trying to make the activation more positive) would exactly cancel.
-            # Instead we make the min_positive constraint stronger, so it first makes the value
-            # sometimes negative, and only when that is satisfied, can deal with the absolute-value
-            # constraint.
-            scale_factor_scale = 0.5
             below_threshold = (self.abs_mean < self.min_abs)
             above_threshold = (self.abs_mean > self.max_abs)
             scale_factor[:] = ((below_threshold.to(torch.float32) -
                                 above_threshold.to(torch.float32))
-                               * (max_factor * (2.0 * scale_factor_scale)))
+                               * (max_factor * 2.0))
 
 
 class MaxEig(torch.nn.Module):
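Editor's note (not part of the patch): with scale_factor_scale removed, the abs-constraint factor is simply ±2*max_factor per channel: channels whose abs_mean is below min_abs get a positive factor, channels above max_abs a negative one. A short sketch with illustrative min_abs, max_abs and max_factor values:

import torch

abs_mean = torch.tensor([0.1, 1.0, 6.0])        # running abs-mean of three channels
min_abs, max_abs, max_factor = 0.2, 4.0, 0.01   # illustrative constraint settings

below_threshold = (abs_mean < min_abs)
above_threshold = (abs_mean > max_abs)
scale_factor = ((below_threshold.to(torch.float32) -
                 above_threshold.to(torch.float32))
                * (max_factor * 2.0))
print(scale_factor)                             # tensor([ 0.0200,  0.0000, -0.0200])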