Modify scaling.py to prevent constant values

Daniel Povey 2022-07-29 09:34:13 +08:00
parent 3c1fddaf48
commit 9d7af4be20


@@ -58,9 +58,10 @@ class ActivationBalancerFunction(torch.autograd.Function):
         if channel_dim < 0:
             channel_dim += x.ndim
         sum_dims = [d for d in range(x.ndim) if d != channel_dim]
-        xgt0 = x > 0
+        x_normalized = x - torch.mean(x, dim=sum_dims, keepdim=True)
+        xgtmean = (x_normalized > 0)
         proportion_positive = torch.mean(
-            xgt0.to(x.dtype), dim=sum_dims, keepdim=True
+            (x > 0).to(x.dtype), dim=sum_dims, keepdim=True
         )
         factor1 = (
             (min_positive - proportion_positive).relu()
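
Note: the old code saved the mask xgt0 = (x > 0) for backward; after this hunk, (x > 0) is used only for proportion_positive, while the saved mask is taken relative to the per-channel mean. A minimal sketch of the new quantities, not part of the commit, with a made-up (batch, channel, time) tensor:

import torch

# Hypothetical toy input: channel 0 has collapsed to a constant,
# channel 1 varies around zero.
x = torch.zeros(4, 2, 3)
x[:, 0, :] = 2.5
x[:, 1, :] = torch.randn(4, 3)

channel_dim = 1
sum_dims = [d for d in range(x.ndim) if d != channel_dim]

x_normalized = x - torch.mean(x, dim=sum_dims, keepdim=True)
xgt0 = x > 0                  # old mask: all True for the constant channel
xgtmean = x_normalized > 0    # new mask: sign relative to the channel mean

The proportion_positive statistic is still computed from (x > 0); only the mask saved for backward changes.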
@@ -74,16 +75,24 @@ class ActivationBalancerFunction(torch.autograd.Function):
             if max_positive != 1.0
             else 0.0
         )
+        # `factor` is a tensor of shape something like (1, 1, num_channels,
+        # 1), containing elements between -1 and 1 that are zero if the
+        # proportion of positive features is between min_positive and
+        # max_positive, max_factor if proportion==0.0 (all features are negative),
+        # and -max_factor if proportion==1.0 (all features are positive). It is
+        # an amount per channel by which we'll modify the gradient; the sign
+        # of modifying the gradient will depend on the sign of the gradient.
         factor = factor1 + factor2
         if isinstance(factor, float):
             factor = torch.zeros_like(proportion_positive)
-        mean_abs = torch.mean(x.abs(), dim=sum_dims, keepdim=True)
+        mean_abs = torch.mean(x_normalized.abs(), dim=sum_dims, keepdim=True)
         below_threshold = mean_abs < min_abs
         above_threshold = mean_abs > max_abs
         ctx.save_for_backward(
-            factor, xgt0, below_threshold, above_threshold
+            factor, xgtmean, below_threshold, above_threshold
         )
         ctx.max_factor = max_factor
         ctx.sum_dims = sum_dims
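
The added comment can be checked with a few numbers. In the sketch below the scaling terms inside factor1/factor2 are reconstructed to match the behaviour the comment describes (they fall outside the visible hunks), and the hyperparameter values are made up:

import torch

min_positive, max_positive, max_factor = 0.05, 0.95, 0.01   # assumed values

def factor_for(proportion_positive):
    # Gives 0 inside [min_positive, max_positive], +max_factor at
    # proportion 0 and -max_factor at proportion 1.
    factor1 = (min_positive - proportion_positive).relu() * (max_factor / min_positive)
    factor2 = (proportion_positive - max_positive).relu() * (max_factor / (max_positive - 1.0))
    return factor1 + factor2

print(factor_for(torch.tensor(0.0)))   # tensor(0.0100)  -> +max_factor
print(factor_for(torch.tensor(0.5)))   # tensor(0.)      -> inside the allowed range
print(factor_for(torch.tensor(1.0)))   # tensor(-0.0100) -> -max_factor

The other change in this hunk is the heart of the commit: mean_abs is now measured on x_normalized, so a channel stuck at a constant value has mean_abs == 0 and trips the min_abs constraint, whereas before a large constant would have looked healthy.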
@@ -93,11 +102,11 @@ class ActivationBalancerFunction(torch.autograd.Function):
     def backward(
         ctx, x_grad: Tensor
     ) -> Tuple[Tensor, None, None, None, None, None, None]:
-        factor, xgt0, below_threshold, above_threshold = ctx.saved_tensors
+        factor, xgtmean, below_threshold, above_threshold = ctx.saved_tensors
         dtype = x_grad.dtype
         scale_factor = (
             (below_threshold.to(dtype) - above_threshold.to(dtype))
-            * (xgt0.to(dtype) - 0.5)
+            * (xgtmean.to(dtype) - 0.5)
             * (ctx.max_factor * 2.0)
         )
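
In backward, scale_factor carries the magnitude constraint, and with this change its sign comes from xgtmean (above or below the channel mean) rather than from xgt0 (above or below zero). A small trace of the values it takes, with made-up masks, not part of the commit:

import torch

max_factor = 0.01                                        # assumed value
dtype = torch.float32

below_threshold = torch.tensor([[True], [False]])        # channel 0: mean_abs < min_abs
above_threshold = torch.tensor([[False], [False]])
xgtmean = torch.tensor([[True, False], [True, False]])   # x above / below its channel mean

scale_factor = (
    (below_threshold.to(dtype) - above_threshold.to(dtype))
    * (xgtmean.to(dtype) - 0.5)
    * (max_factor * 2.0)
)
# channel 0: [+0.01, -0.01] -> opposite signs on the two sides of the channel mean,
#            so the magnitude constraint now acts relative to the mean, not to zero.
# channel 1: [0., 0.]       -> within [min_abs, max_abs], gradient left unchanged.
print(scale_factor)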
@@ -105,6 +114,8 @@ class ActivationBalancerFunction(torch.autograd.Function):
         return x_grad - neg_delta_grad, None, None, None, None, None, None
+
+
 class BasicNorm(torch.nn.Module):
     """
     This is intended to be a simpler, and hopefully cheaper, replacement for
@@ -404,12 +415,12 @@ class ActivationBalancer(torch.nn.Module):
              either the sign constraint or the magnitude constraint;
              e.g. with max_factor=0.02, the the derivatives would be multiplied by
              values in the range [0.98..1.02].
-           min_abs: the minimum average-absolute-value per channel, which
-              we allow, before we start to modify the derivatives to prevent
-              this.
-           max_abs: the maximum average-absolute-value per channel, which
-              we allow, before we start to modify the derivatives to prevent
-              this.
+           min_abs: the minimum average-absolute-value difference from the mean
+              value per channel, which we allow, before we start to modify
+              the derivatives to prevent this.
+           max_abs: the maximum average-absolute-value difference from the mean
+              value per channel, which we allow, before we start to modify
+              the derivatives to prevent this.
     """
     def __init__(
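
The docstring change mirrors the forward() change: min_abs and max_abs now bound the mean absolute deviation from the per-channel mean rather than the mean absolute value. A small illustration of the difference, with made-up values:

import torch

# A channel that hovers around a large positive offset.
x_channel = torch.tensor([4.9, 5.0, 5.1, 5.0])

old_meaning = x_channel.abs().mean()                        # 5.00: easily passes a min_abs check
new_meaning = (x_channel - x_channel.mean()).abs().mean()   # 0.05: likely below min_abs
print(old_meaning.item(), new_meaning.item())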