Modify scaling.py to prevent constant values

Daniel Povey 2022-07-29 09:34:13 +08:00
parent 3c1fddaf48
commit 9d7af4be20

@@ -58,9 +58,10 @@ class ActivationBalancerFunction(torch.autograd.Function):
         if channel_dim < 0:
             channel_dim += x.ndim
         sum_dims = [d for d in range(x.ndim) if d != channel_dim]
-        xgt0 = x > 0
+        x_normalized = x - torch.mean(x, dim=sum_dims, keepdim=True)
+        xgtmean = (x_normalized > 0)
         proportion_positive = torch.mean(
-            xgt0.to(x.dtype), dim=sum_dims, keepdim=True
+            (x > 0).to(x.dtype), dim=sum_dims, keepdim=True
         )
         factor1 = (
             (min_positive - proportion_positive).relu()
@@ -74,16 +75,24 @@ class ActivationBalancerFunction(torch.autograd.Function):
             if max_positive != 1.0
             else 0.0
         )
+        # `factor` is a tensor of shape something like (1, 1, num_channels,
+        # 1), containing elements between -1 and 1 that are zero if the
+        # proportion of positive features is between min_positive and
+        # max_positive, max_factor if proportion==0.0 (all features are negative),
+        # and -max_factor if proportion==1.0 (all features are positive). It is
+        # an amount per channel by which we'll modify the gradient; the sign
+        # of modifying the gradient will depend on the sign of the gradient.
         factor = factor1 + factor2
         if isinstance(factor, float):
             factor = torch.zeros_like(proportion_positive)

-        mean_abs = torch.mean(x.abs(), dim=sum_dims, keepdim=True)
+        mean_abs = torch.mean(x_normalized.abs(), dim=sum_dims, keepdim=True)
         below_threshold = mean_abs < min_abs
         above_threshold = mean_abs > max_abs

         ctx.save_for_backward(
-            factor, xgt0, below_threshold, above_threshold
+            factor, xgtmean, below_threshold, above_threshold
         )
         ctx.max_factor = max_factor
         ctx.sum_dims = sum_dims
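For orientation, a small self-contained sketch (not part of the commit) of the per-channel statistics the reworked forward() collects. The x_normalized, xgtmean, proportion_positive and mean_abs lines follow the hunks above; the tensor shape, the hyperparameter values, and the exact factor2 formula are assumptions, chosen so that factor matches the behaviour the new comment describes (max_factor when no elements are positive, -max_factor when all are). A constant offset added to x still shifts proportion_positive, but it no longer inflates mean_abs or biases the sign tensor saved for backward().

import torch

x = torch.randn(4, 100, 256) + 5.0   # hypothetical (batch, time, channels) input with a large constant offset
channel_dim = 2
sum_dims = [d for d in range(x.ndim) if d != channel_dim]

# As in the new forward(): magnitude and gradient-sign statistics are taken
# relative to the per-channel mean, so the constant offset is removed first.
x_normalized = x - torch.mean(x, dim=sum_dims, keepdim=True)
xgtmean = x_normalized > 0
proportion_positive = torch.mean((x > 0).to(x.dtype), dim=sum_dims, keepdim=True)
mean_abs = torch.mean(x_normalized.abs(), dim=sum_dims, keepdim=True)

# Assumed hyperparameters; factor is zero wherever proportion_positive
# stays inside [min_positive, max_positive].
min_positive, max_positive, max_factor = 0.05, 0.95, 0.01
min_abs, max_abs = 0.2, 100.0
factor1 = (min_positive - proportion_positive).relu() * (max_factor / min_positive)
factor2 = (proportion_positive - max_positive).relu() * (max_factor / (max_positive - 1.0))
factor = factor1 + factor2
below_threshold = mean_abs < min_abs
above_threshold = mean_abs > max_abs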
@@ -93,11 +102,11 @@ class ActivationBalancerFunction(torch.autograd.Function):
     def backward(
         ctx, x_grad: Tensor
     ) -> Tuple[Tensor, None, None, None, None, None, None]:
-        factor, xgt0, below_threshold, above_threshold = ctx.saved_tensors
+        factor, xgtmean, below_threshold, above_threshold = ctx.saved_tensors
         dtype = x_grad.dtype
         scale_factor = (
             (below_threshold.to(dtype) - above_threshold.to(dtype))
-            * (xgt0.to(dtype) - 0.5)
+            * (xgtmean.to(dtype) - 0.5)
             * (ctx.max_factor * 2.0)
         )
@@ -105,6 +114,8 @@ class ActivationBalancerFunction(torch.autograd.Function):
         return x_grad - neg_delta_grad, None, None, None, None, None, None


 class BasicNorm(torch.nn.Module):
     """
     This is intended to be a simpler, and hopefully cheaper, replacement for
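The backward side, again as an illustrative stand-alone sketch rather than a quote of the file: scale_factor mirrors the expression in the hunk above, now keyed off xgtmean instead of xgt0, so magnitude corrections are signed by each element's position relative to its channel mean. The shapes and the final two lines (how neg_delta_grad is formed) are assumptions, since that part of backward() lies outside the hunks shown.

import torch

# Stand-ins for the tensors forward() saved; shapes and values are made up.
x_grad = torch.randn(4, 100, 256)
factor = torch.zeros(1, 1, 256)
xgtmean = torch.rand(4, 100, 256) > 0.5
below_threshold = torch.zeros(1, 1, 256, dtype=torch.bool)
above_threshold = torch.ones(1, 1, 256, dtype=torch.bool)
max_factor = 0.01

dtype = x_grad.dtype
# +max_factor where a channel's mean_abs fell below min_abs, -max_factor where
# it exceeded max_abs, with the per-element sign taken from xgtmean.
scale_factor = (
    (below_threshold.to(dtype) - above_threshold.to(dtype))
    * (xgtmean.to(dtype) - 0.5)
    * (max_factor * 2.0)
)
# Assumption (outside the shown hunks): the correction scales with the
# gradient magnitude before being subtracted.
neg_delta_grad = x_grad.abs() * (factor + scale_factor)
new_x_grad = x_grad - neg_delta_grad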
@@ -404,12 +415,12 @@ class ActivationBalancer(torch.nn.Module):
           either the sign constraint or the magnitude constraint;
           e.g. with max_factor=0.02, the derivatives would be multiplied by
           values in the range [0.98..1.02].
-       min_abs: the minimum average-absolute-value per channel, which
-          we allow, before we start to modify the derivatives to prevent
-          this.
-       max_abs: the maximum average-absolute-value per channel, which
-          we allow, before we start to modify the derivatives to prevent
-          this.
+       min_abs: the minimum average-absolute-value difference from the mean
+          value per channel, which we allow, before we start to modify
+          the derivatives to prevent this.
+       max_abs: the maximum average-absolute-value difference from the mean
+          value per channel, which we allow, before we start to modify
+          the derivatives to prevent this.
     """

     def __init__(
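To close, a hypothetical usage sketch of the module these hunks touch. The keyword names are taken from the docstring above, but __init__ itself is truncated in this view, so the exact constructor signature and default values are assumptions.

import torch
import torch.nn as nn

from scaling import ActivationBalancer  # the file this commit modifies

class TinyBlock(nn.Module):
    """Hypothetical block: linear layer -> balancer -> ReLU."""

    def __init__(self, dim: int = 256):
        super().__init__()
        self.linear = nn.Linear(dim, dim)
        # Argument names follow the docstring above; treat them as assumed.
        self.balancer = ActivationBalancer(
            channel_dim=-1,
            min_positive=0.05,
            max_positive=0.95,
            max_factor=0.01,
            min_abs=0.2,
            max_abs=100.0,
        )
        self.activation = nn.ReLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # The balancer only adjusts gradients in backward(); in this sketch it
        # is assumed to return its input unchanged in the forward direction.
        return self.activation(self.balancer(self.linear(x)))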