Make the ActivationBalancer relative to the mean, limited to -min_abs..max_abs

Daniel Povey 2022-12-09 17:59:00 +08:00
parent 912adfff7c
commit 2ef0228db0


@@ -105,17 +105,18 @@ class ActivationBalancerFunction(torch.autograd.Function):
         ctx,
         x: Tensor,
         scale_factor: Tensor,
+        mean: Tensor,
         sign_factor: Optional[Tensor],
         channel_dim: int,
     ) -> Tensor:
         if channel_dim < 0:
             channel_dim += x.ndim
         ctx.channel_dim = channel_dim
-        xgt0 = (x > 0)
+        xgtmean = (x > mean)
         if sign_factor is None:
-            ctx.save_for_backward(xgt0, scale_factor)
+            ctx.save_for_backward(xgtmean, scale_factor)
         else:
-            ctx.save_for_backward(xgt0, scale_factor, sign_factor)
+            ctx.save_for_backward(xgtmean, scale_factor, sign_factor)
         return x
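To make the change concrete, here is a minimal sketch of what the forward pass now saves for backward (toy tensors and the min_abs value are invented for illustration; in the module, mean arrives already clamped from _compute_scale_factor):

    import torch

    x = torch.tensor([[0.5, -0.2, 1.0],
                      [-0.5, 0.4, 2.0]])           # (batch=2, num_channels=3)
    min_abs = 0.2
    mean = x.mean(dim=0, keepdim=True).clamp(min=-min_abs, max=min_abs)

    # Before this commit the saved mask was (x > 0); it is now taken
    # relative to the clamped per-channel mean.
    xgtmean = (x > mean)                           # bool, saved for backward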
@@ -124,29 +125,48 @@ class ActivationBalancerFunction(torch.autograd.Function):
         ctx, x_grad: Tensor
     ) -> Tuple[Tensor, None, None, None]:
         if len(ctx.saved_tensors) == 3:
-            xgt0, scale_factor, sign_factor = ctx.saved_tensors
+            xgtmean, scale_factor, sign_factor = ctx.saved_tensors
             for _ in range(ctx.channel_dim, x_grad.ndim - 1):
                 scale_factor = scale_factor.unsqueeze(-1)
                 sign_factor = sign_factor.unsqueeze(-1)
-            factor = sign_factor + scale_factor * (xgt0.to(x_grad.dtype) - 0.5)
+            factor = sign_factor + scale_factor * (xgtmean.to(x_grad.dtype) - 0.5)
         else:
-            xgt0, scale_factor = ctx.saved_tensors
+            xgtmean, scale_factor = ctx.saved_tensors
             for _ in range(ctx.channel_dim, x_grad.ndim - 1):
                 scale_factor = scale_factor.unsqueeze(-1)
-            factor = scale_factor * (xgt0.to(x_grad.dtype) - 0.5)
+            factor = scale_factor * (xgtmean.to(x_grad.dtype) - 0.5)
         neg_delta_grad = x_grad.abs() * factor
-        return x_grad - neg_delta_grad, None, None, None,
+        return x_grad - neg_delta_grad, None, None, None, None
 def _compute_scale_factor(x: Tensor,
                           channel_dim: int,
                           min_abs: float,
                           max_abs: float,
                           gain_factor: float,
-                          max_factor: float) -> Tensor:
+                          max_factor: float) -> Tuple[Tensor, Tensor]:
"""
Computes a factor used in ActivationBalancer, that dictates how much we penalize (or anti-penalize)
the scale on the features.
Returns: (scale_factor, mean)
dim.
scale_factor: can be positive or negative, between -max_factor and max_factor; dictates
penalty or anti-penalty. It is of shape (num_channels,)
mean: mean per channel that we use for purposes of scale_factor; actually is clamped to
-min_abs..min_abs. Its like (1, num_channels, 1, 1) depending on the shape of x and
channel-dim.
"""
     if channel_dim < 0:
         channel_dim += x.ndim
     sum_dims = [d for d in range(x.ndim) if d != channel_dim]
-    x_abs_mean = torch.mean(x.abs(), dim=sum_dims).to(torch.float32)
+    x_mean = torch.mean(x, dim=sum_dims, keepdim=True).to(torch.float32)
+    # The idea is that, for purposes of applying max_abs, we regress effectively
+    # toward zero (assuming min_abs is much less than max_abs).
+    x_mean = x_mean.clamp(min=-min_abs, max=min_abs)
+    x_abs_mean = torch.mean((x - x_mean).abs(), dim=sum_dims).to(torch.float32)

     if min_abs == 0.0:
         below_threshold = 0.0
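The diff elides the non-zero below-threshold branch, so a standalone sketch of the new computation may help. The below_threshold formula here mirrors the above_threshold line shown in the next hunk and is an assumption about the elided code; all constants are invented:

    import torch

    def toy_scale_factor(x, channel_dim=1, min_abs=0.2, max_abs=1.0,
                         gain_factor=0.02, max_factor=0.04):
        # Mirrors the new logic: clamp the per-channel mean to -min_abs..min_abs,
        # then measure the mean absolute deviation around that clamped mean.
        sum_dims = [d for d in range(x.ndim) if d != channel_dim]
        x_mean = torch.mean(x, dim=sum_dims, keepdim=True).to(torch.float32)
        x_mean = x_mean.clamp(min=-min_abs, max=min_abs)
        x_abs_mean = torch.mean((x - x_mean).abs(), dim=sum_dims).to(torch.float32)
        below_threshold = ((min_abs - x_abs_mean) * (gain_factor / min_abs)).clamp(min=0, max=max_factor)
        above_threshold = ((x_abs_mean - max_abs) * (gain_factor / max_abs)).clamp(min=0, max=max_factor)
        return below_threshold - above_threshold, x_mean

    x = torch.randn(4, 3, 10)                      # (batch, channels, time)
    scale_factor, mean = toy_scale_factor(x)
    print(scale_factor.shape, mean.shape)          # torch.Size([3]) torch.Size([1, 3, 1])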
@@ -157,7 +177,7 @@ def _compute_scale_factor(x: Tensor,
     above_threshold = ((x_abs_mean - max_abs) * (gain_factor / max_abs)).clamp(min=0, max=max_factor)

-    return below_threshold - above_threshold
+    return below_threshold - above_threshold, x_mean


 def _compute_sign_factor(x: Tensor,
                          channel_dim: int,
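Putting backward() together with these factors, here is a toy demonstration of the gradient tweak (all numbers invented; sign_factor is omitted, i.e. the two-saved-tensors branch):

    import torch

    x = torch.tensor([[0.5, -0.2, 1.0]])
    mean = torch.tensor([[0.1, 0.0, 0.2]])         # clamped per-channel mean
    scale_factor = torch.tensor([0.04, -0.02, 0.0])
    x_grad = torch.tensor([[1.0, -2.0, 0.5]])      # incoming gradient

    xgtmean = (x > mean)
    factor = scale_factor * (xgtmean.to(x_grad.dtype) - 0.5)
    neg_delta_grad = x_grad.abs() * factor
    new_grad = x_grad - neg_delta_grad             # tensor([[0.98, -2.02, 0.50]])

With scale_factor > 0 (channel magnitude below min_abs), values above the mean get a slightly more negative gradient and values below it a slightly more positive one, nudging the channel's spread around the mean outward; scale_factor < 0 does the reverse.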
@ -679,13 +699,13 @@ class ActivationBalancer(torch.nn.Module):
sign_factor = None sign_factor = None
scale_factor = _compute_scale_factor(x, self.channel_dim, scale_factor, mean = _compute_scale_factor(x, self.channel_dim,
min_abs=float(self.min_abs), min_abs=float(self.min_abs),
max_abs=float(self.max_abs), max_abs=float(self.max_abs),
gain_factor=float(self.scale_gain_factor) / prob, gain_factor=float(self.scale_gain_factor) / prob,
max_factor=float(self.max_factor)) max_factor=float(self.max_factor))
return ActivationBalancerFunction.apply( return ActivationBalancerFunction.apply(
x, scale_factor, sign_factor, self.channel_dim, x, scale_factor, mean, sign_factor, self.channel_dim,
) )
else: else:
return _no_op(x) return _no_op(x)
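Finally, a hypothetical smoke test of the module after this change. From the user's point of view the call is unchanged; the constructor arguments are assumptions about this revision's signature (only channel_dim, min_abs and max_abs appear in the diff):

    import torch
    from scaling import ActivationBalancer        # icefall's scaling.py

    # Constructor arguments are assumed, not taken from the commit.
    balancer = ActivationBalancer(num_channels=256, channel_dim=1,
                                  min_abs=0.2, max_abs=100.0)
    x = torch.randn(4, 256, 50, requires_grad=True)
    y = balancer(x)            # identity in the forward pass
    y.sum().backward()         # backward applies the balancing tweak shown above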