Reworking of ActivationBalancer code to hopefully balance speed and effectiveness.

Daniel Povey 2022-10-14 19:20:32 +08:00
parent 5f375be159
commit 96023419da


@@ -30,6 +30,104 @@ from torch.nn import Embedding as ScaledEmbedding
class ActivationBalancerFunction(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx,
        x: Tensor,
        scale_factor: Tensor,
        sign_factor: Optional[Tensor],
        channel_dim: int,
    ) -> Tensor:
        if channel_dim < 0:
            channel_dim += x.ndim
        ctx.channel_dim = channel_dim
        xgt0 = (x > 0)
        if sign_factor is None:
            ctx.save_for_backward(xgt0, scale_factor)
        else:
            ctx.save_for_backward(xgt0, scale_factor, sign_factor)
        return x

    @staticmethod
    def backward(
        ctx, x_grad: Tensor
    ) -> Tuple[Tensor, None, None, None]:
        if len(ctx.saved_tensors) == 3:
            xgt0, scale_factor, sign_factor = ctx.saved_tensors
            for _ in range(ctx.channel_dim, x_grad.ndim - 1):
                scale_factor = scale_factor.unsqueeze(-1)
                sign_factor = sign_factor.unsqueeze(-1)
            factor = sign_factor + scale_factor * (xgt0.to(x_grad.dtype) - 0.5)
        else:
            xgt0, scale_factor = ctx.saved_tensors
            for _ in range(ctx.channel_dim, x_grad.ndim - 1):
                scale_factor = scale_factor.unsqueeze(-1)
            factor = scale_factor * (xgt0.to(x_grad.dtype) - 0.5)
        neg_delta_grad = x_grad.abs() * factor
        return x_grad - neg_delta_grad, None, None, None,
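
A rough usage sketch of this function (illustrative only; it assumes the code above is importable, e.g. as a module named scaling): the forward pass is the identity, and the backward pass subtracts x_grad.abs() * factor from the incoming gradient, where factor combines the sign and scale corrections.

import torch
from scaling import ActivationBalancerFunction  # assumed module path

x = torch.randn(4, 3, requires_grad=True)
scale_factor = torch.full((3,), 0.01)   # per-channel magnitude correction
sign_factor = torch.full((3,), 0.02)    # per-channel sign correction
y = ActivationBalancerFunction.apply(x, scale_factor, sign_factor, -1)
assert torch.equal(y, x)                # forward is the identity
y.backward(torch.ones_like(x))
# factor = sign_factor + scale_factor * (1[x>0] - 0.5), so with unit incoming
# gradients, x.grad is 0.975 where x > 0 and 0.985 where x <= 0.
print(x.grad)
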
def _compute_scale_factor(x: Tensor,
                          channel_dim: int,
                          min_abs: float,
                          max_abs: float,
                          gain_factor: float,
                          max_factor: float) -> Tensor:
    if channel_dim < 0:
        channel_dim += x.ndim
    sum_dims = [d for d in range(x.ndim) if d != channel_dim]
    x_abs_mean = torch.mean(x.abs(), dim=sum_dims).to(torch.float32)

    if min_abs == 0.0:
        below_threshold = 0.0
    else:
        # below_threshold is 0 if x_abs_mean > min_abs; it can be as large as
        # max_factor if x_abs_mean is well below min_abs.
        below_threshold = ((min_abs - x_abs_mean) * (gain_factor / min_abs)).clamp(min=0, max=max_factor)

    above_threshold = ((x_abs_mean - max_abs) * (gain_factor / max_abs)).clamp(min=0, max=max_factor)

    return below_threshold - above_threshold
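
A quick numeric illustration of _compute_scale_factor (a sketch, with an assumed module path): a channel whose mean absolute value is below min_abs gets a positive factor, a channel inside [min_abs, max_abs] gets zero, and a channel above max_abs gets a negative factor.

import torch
from scaling import _compute_scale_factor  # assumed module path

x = torch.stack([torch.full((100,), 0.05),    # mean |x| below min_abs=0.2
                 torch.full((100,), 1.0),     # inside [min_abs, max_abs]
                 torch.full((100,), 200.0)],  # above max_abs=100.0
                dim=1)
f = _compute_scale_factor(x, channel_dim=1, min_abs=0.2, max_abs=100.0,
                          gain_factor=0.02, max_factor=0.04)
# (0.2 - 0.05) * (0.02 / 0.2) = 0.015;  (200 - 100) * (0.02 / 100) = 0.02
print(f)  # tensor([ 0.0150,  0.0000, -0.0200])
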
def _compute_sign_factor(x: Tensor,
                         channel_dim: int,
                         min_positive: float,
                         max_positive: float,
                         gain_factor: float,
                         max_factor: float) -> Tensor:
    if channel_dim < 0:
        channel_dim += x.ndim
    sum_dims = [d for d in range(x.ndim) if d != channel_dim]
    proportion_positive = torch.mean((x > 0).to(torch.float32),
                                     dim=sum_dims)
    if min_positive == 0.0:
        factor1 = 0.0
    else:
        # 0 if proportion_positive >= min_positive, else can be
        # as large as max_factor.
        factor1 = ((min_positive - proportion_positive) *
                   (gain_factor / min_positive)).clamp_(min=0, max=max_factor)

    if max_positive == 1.0:
        factor2 = 0.0
    else:
        # 0 if proportion_positive <= max_positive, else can be as large as
        # max_factor (it is subtracted below, so its contribution to
        # sign_factor can be as large as -max_factor).
        factor2 = ((proportion_positive - max_positive) *
                   (gain_factor / (1.0 - max_positive))).clamp_(min=0, max=max_factor)
    sign_factor = factor1 - factor2
    # this requires min_positive != 0 or max_positive != 1:
    assert not isinstance(sign_factor, float)
    return sign_factor
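
A similar sanity check for _compute_sign_factor (again a sketch with an assumed module path): a channel that is never positive violates min_positive and receives a positive sign_factor, while a roughly half-positive channel receives zero.

import torch
from scaling import _compute_sign_factor  # assumed module path

x = torch.stack([torch.full((100,), -1.0),          # 0% positive: below min_positive
                 torch.linspace(-1.0, 1.0, 100)],   # ~50% positive: unconstrained
                dim=1)
sf = _compute_sign_factor(x, channel_dim=1, min_positive=0.05, max_positive=0.95,
                          gain_factor=0.01, max_factor=0.04)
# channel 0: (0.05 - 0.0) * (0.01 / 0.05) = 0.01;  channel 1: both terms clamp to 0
print(sf)  # tensor([0.0100, 0.0000])
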
class ActivationScaleBalancerFunction(torch.autograd.Function):
    """
    This object is used in class ActivationBalancer when the user specifies
    min_positive=0, max_positive=1, so there are no constraints on the signs
    of the activations and only the absolute value has a constraint.
    """
    @staticmethod
    def forward(
        ctx,
@@ -62,6 +160,7 @@ class ActivationBalancerFunction(torch.autograd.Function):
class MaxEigLimiterFunction(torch.autograd.Function):
    @staticmethod
    def forward(
@@ -218,7 +317,6 @@ class ActivationBalancer(torch.nn.Module):
    interpolated from 1 at the threshold to those extremal values when none
    of the inputs are positive.
    Args:
       num_channels: the number of channels
       channel_dim: the dimension/axis corresponding to the channel, e.g.
@@ -231,20 +329,23 @@ class ActivationBalancer(torch.nn.Module):
          either the sign constraint or the magnitude constraint;
          e.g. with max_factor=0.02, the derivatives would be multiplied by
          values in the range [0.98..1.02].
       sign_gain_factor: determines the 'gain' with which we increase the
          change in gradient once the constraints on min_positive and max_positive
          are violated.
       scale_gain_factor: determines the 'gain' with which we increase the
          change in gradient once the constraints on min_abs and max_abs
          are violated.
       min_abs: the minimum average-absolute-value difference from the mean
          value per channel, which we allow, before we start to modify
          the derivatives to prevent this.
       max_abs: the maximum average-absolute-value difference from the mean
          value per channel, which we allow, before we start to modify
          the derivatives to prevent this.
-      beta: a constant used in decaying stats for the {min,max}_positive and
-         {min,max}_abs constraints. Likely not critical.
-      prob: determines the probability with which we modify the
       min_prob: determines the minimum probability with which we modify the
          gradients for the {min,max}_positive and {min,max}_abs constraints,
          on each forward(). This is done randomly to prevent all layers
-         from doing it at the same time.
-      stats_period: the periodicity with which we update the statistics on
-         the activations.
          from doing it at the same time. Early in training we may use
          higher probabilities than this; it will decay to this value.
    """

    def __init__(
        self,
@@ -252,13 +353,12 @@ class ActivationBalancer(torch.nn.Module):
        channel_dim: int,
        min_positive: float = 0.05,
        max_positive: float = 0.95,
-       max_factor: float = 0.01,
        max_factor: float = 0.02,
        sign_gain_factor: float = 0.01,
        scale_gain_factor: float = 0.02,
        min_abs: float = 0.2,
        max_abs: float = 100.0,
-       max_var_per_eig: float = 0.0,
        min_prob: float = 0.1,
-       beta: float = 0.0,
-       prob: float = 0.25,
-       stats_period: int = 4,
    ):
        super(ActivationBalancer, self).__init__()
        self.num_channels = num_channels
@@ -268,124 +368,58 @@ class ActivationBalancer(torch.nn.Module):
        self.max_factor = max_factor
        self.min_abs = min_abs
        self.max_abs = max_abs
-       self.beta = beta
-       self.prob = prob
-       self.stats_period = stats_period
        self.min_prob = min_prob
        self.sign_gain_factor = sign_gain_factor
        self.scale_gain_factor = scale_gain_factor
        # count measures how many times the forward() function has been called.
-       self.count = 0
        # We occasionally sync this to a tensor called `count`, that exists to
        # make sure it is synced to disk when we load and save the model.
        self.cpu_count = 0
        self.register_buffer('count', torch.tensor(0, dtype=torch.int64))
-       # the mean of the absolute value of the data per channel
-       self.register_buffer('abs_mean', torch.zeros(num_channels))
-       # the proportion of activations that are positive, per channel.
-       self.register_buffer('proportion_positive', torch.zeros(num_channels))
-       # `factors` contains two buffers of shape (num_channels,).
-       # `sign_factor` is an expression that will be used to scale the
-       # gradients in backprop; it will be 0 if the max_positive and min_positive
-       # constraints are satisfied.
-       # `scale_factor` is an expression that will be used to encourage the
-       # data to satisfy our min_abs and max_abs constraints; it will be zero if
-       # all constraints are satisfied.
-       self.register_buffer('factors', torch.zeros(2, num_channels))

    def forward(self, x: Tensor) -> Tensor:
        if torch.jit.is_scripting() or not x.requires_grad:
            return x
-       count = self.count
-       self.count += 1
-       if count % self.stats_period == 0:
-           self._update_stats(x, count)
        count = self.cpu_count
        self.cpu_count += 1
        if random.random() < 0.01:
            # Occasionally sync self.cpu_count with self.count.
            # count affects the decay of 'prob'. don't do this on every iter,
            # because syncing with the GPU is slow.
            self.cpu_count = max(self.cpu_count, self.count.item())
            self.count.fill_(self.cpu_count)

-       if random.random() < self.prob:
-           # The .clone() is in case the forward() gets called multiple times before
-           factors = self.factors.clone()
-           sign_factor = factors[0]
-           scale_factor = factors[1]
        # the prob of doing some work exponentially decreases from 0.5 till it hits
        # a floor at min_prob (==0.1, by default)
        prob = max(self.min_prob, 0.5 ** (1 + (count / 2000.0)))

        if random.random() < prob:
            sign_gain_factor = 0.5
            if self.min_positive != 0.0 or self.max_positive != 1.0:
                sign_factor = _compute_sign_factor(x, self.channel_dim,
                                                   self.min_positive, self.max_positive,
                                                   gain_factor=self.sign_gain_factor / prob,
                                                   max_factor=self.max_factor)
            else:
                sign_factor = None

            scale_factor = _compute_scale_factor(x, self.channel_dim,
                                                 min_abs=self.min_abs,
                                                 max_abs=self.max_abs,
                                                 gain_factor=self.scale_gain_factor / prob,
                                                 max_factor=self.max_factor)
            return ActivationBalancerFunction.apply(
-               x, sign_factor, scale_factor, self.channel_dim,
                x, scale_factor, sign_factor, self.channel_dim,
            )
        else:
            return x
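
The schedule implied by prob = max(self.min_prob, 0.5 ** (1 + count / 2000.0)) can be tabulated directly; with the default min_prob of 0.1 it starts at 0.5 and reaches the floor after roughly 4600 calls (illustrative only):

for count in (0, 2000, 4000, 5000):
    print(count, max(0.1, 0.5 ** (1 + count / 2000.0)))
# 0 -> 0.5, 2000 -> 0.25, 4000 -> 0.125, 5000 -> 0.1 (floor)
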
-   def _update_stats(self,
-                     x: Tensor,
-                     count: int):
-       """
-       Updates some statistics that we maintain, describing the average activations per
-       channel.
-       """
-       with torch.no_grad():
-           channel_dim = self.channel_dim
-           if channel_dim < 0:
-               channel_dim += x.ndim
-           sum_dims = [d for d in range(x.ndim) if d != channel_dim]
-           x_abs_mean = torch.mean(x.abs(), dim=sum_dims).to(torch.float32)
-           # the random.random() thing is to split the difference if x is zero,
-           # between treating it positive or negative
-           proportion_positive = torch.mean(
-               ((x > 0) if random.random() < 0.5 else (x >= 0)).to(torch.float32), dim=sum_dims,
-           )
-
-           def filter_inf_nan(y):
-               mask = (y - y != 0)
-               y.masked_fill_(mask, 0.0)
-
-           filter_inf_nan(x_abs_mean)
-           beta = self.beta if count > 0 else 0.0
-           self.abs_mean.mul_(beta).add_(x_abs_mean, alpha=(1-beta))
-           self.proportion_positive.mul_(beta).add_(proportion_positive, alpha=(1-beta))
-
-           max_factor = self.max_factor / self.prob
-           min_positive = self.min_positive
-           max_positive = self.max_positive
-
-           if min_positive == 0.0:
-               factor1 = 0.0
-           else:
-               # 0 if self.proportion_positive >= min_positive, else can be
-               # as large as max_factor.
-               factor1 = ((min_positive - self.proportion_positive).relu() *
-                          (max_factor / min_positive))
-
-           if max_positive == 1.0:
-               factor2 = 0.0
-           else:
-               # 0 if self.proportion_positive <= max_positive, else can be
-               # as large as -max_factor.
-               factor2 = ((self.proportion_positive - max_positive).relu()
-                          * (max_factor / (max_positive - 1.0)))
-
-           sign_factor = self.factors[0]
-           scale_factor = self.factors[1]
-           sign_factor[:] = factor1 + factor2
-
-           # the factor of 2.0 below is just to cancel out a factor of 0.5 that gets introduced when, in
-           # the backprop, we do (xgt0.to(dtype) - 0.5).
-           #
-           # scale_factor_scale, on the other hand, is a heuristically chosen value between 0 and 1,
-           # that we use to make the gradient changes from the 'scale' constraints (min_abs/max_abs)
-           # less strong than those from the sign constraints.
-           #
-           # This is to get rid of a pathology that can happen if, for instance, a
-           # channel is always positive but is too small (max_positive and min_abs constraints both
-           # violated). If scale_factor_scale were equal to 1.0, then the gradient changes from the
-           # min_positive constraint (trying to make the activation more negative) and from the
-           # min_abs constraint (trying to make the activation more positive) would exactly cancel.
-           # Instead we make the min_positive constraint stronger, so it first makes the value
-           # sometimes negative, and only when that is satisfied, can deal with the absolute-value
-           # constraint.
-           scale_factor_scale = 0.8
-           below_threshold = (self.abs_mean < self.min_abs)
-           above_threshold = (self.abs_mean > self.max_abs)
-           scale_factor[:] = ((below_threshold.to(torch.float32) -
-                               above_threshold.to(torch.float32))
-                              * (max_factor * (2.0 * scale_factor_scale)))

class MaxEig(torch.nn.Module):
    """
@@ -612,7 +646,6 @@ def _test_activation_balancer_sign():
        max_positive=0.95,
        max_factor=0.2,
        min_abs=0.0,
-       prob=1.0,
    )

    y_grad = torch.sign(torch.randn(probs.numel(), N))
@@ -640,7 +673,7 @@ def _test_activation_balancer_magnitude():
        max_factor=0.2,
        min_abs=0.2,
        max_abs=0.8,
-       prob=1.0,
        min_prob=1.0,
    )

    y_grad = torch.sign(torch.randn(magnitudes.numel(), N))
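
For reference, a minimal usage sketch of the updated ActivationBalancer with its new defaults (assuming the updated scaling.py is importable; the module path is assumed):

import torch
from scaling import ActivationBalancer  # assumed module path

balancer = ActivationBalancer(num_channels=256, channel_dim=-1)
x = torch.randn(10, 20, 256, requires_grad=True)
y = balancer(x)        # identity in the forward pass
y.sum().backward()     # with probability `prob`, gradients are nudged toward
                       # the min/max_positive and min/max_abs constraints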