Add max-abs-value constraint in DerivBalancer

Author: Daniel Povey
Date:   2022-03-13 11:52:13 +08:00
Parent: 6042c96db2
Commit: e6a501d3c8
2 changed files with 31 additions and 13 deletions

@@ -325,9 +325,11 @@ class DerivBalancerFunction(torch.autograd.Function):
     @staticmethod
     def forward(ctx, x: Tensor,
                 channel_dim: int,
-                threshold: float = 0.05,
-                max_factor: float = 0.05,
-                min_abs: float = 0.5) -> Tensor:
+                threshold: float,  # e.g. 0.05
+                max_factor: float,  # e.g. 0.01
+                min_abs: float,  # e.g. 0.2
+                max_abs: float,  # e.g. 1000.0
+                ) -> Tensor:
         if x.requires_grad:
             if channel_dim < 0:
                 channel_dim += x.ndim
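(Note: torch.autograd.Function.apply passes arguments positionally and, at least in PyTorch of this era, rejects keyword arguments, so the defaults that used to sit on forward() were never reachable through apply() anyway; after this change the defaults live only on the DerivBalancer module below, and the "e.g." comments record typical values. A minimal sketch of the resulting call, assuming DerivBalancerFunction is in scope; the values are just the suggested defaults, not repo code:)

    import torch

    # Sketch, not repo code: all five scalars must be passed positionally
    # through apply(); keyword arguments would raise a TypeError here.
    x = torch.randn(2, 5, 8, requires_grad=True)
    y = DerivBalancerFunction.apply(x, -1, 0.05, 0.01, 0.2, 1000.0)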
@@ -336,23 +338,26 @@ class DerivBalancerFunction(torch.autograd.Function):
             proportion_positive = torch.mean(xgt0.to(x.dtype), dim=sum_dims, keepdim=True)
             factor = (threshold - proportion_positive).relu() * (max_factor / threshold)
-            below_threshold = (torch.mean(x.abs(), dim=sum_dims, keepdim=True) < min_abs)
+            mean_abs = torch.mean(x.abs(), dim=sum_dims, keepdim=True)
+            below_threshold = (mean_abs < min_abs)
+            above_threshold = (mean_abs > max_abs)
 
-            ctx.save_for_backward(factor, xgt0, below_threshold)
+            ctx.save_for_backward(factor, xgt0, below_threshold, above_threshold)
             ctx.max_factor = max_factor
             ctx.sum_dims = sum_dims
         return x
 
     @staticmethod
-    def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None, None]:
-        factor, xgt0, below_threshold = ctx.saved_tensors
+    def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None, None, None]:
+        factor, xgt0, below_threshold, above_threshold = ctx.saved_tensors
         dtype = x_grad.dtype
-        too_small_factor = below_threshold.to(dtype) * (xgt0.to(dtype) - 0.5) * (ctx.max_factor * 2.0)
-        neg_delta_grad = x_grad.abs() * (factor + too_small_factor)
-        return x_grad - neg_delta_grad, None, None, None, None
+        scale_factor = ((below_threshold.to(dtype) - above_threshold.to(dtype)) *
+                        (xgt0.to(dtype) - 0.5) * (ctx.max_factor * 2.0))
+        neg_delta_grad = x_grad.abs() * (factor + scale_factor)
+        return x_grad - neg_delta_grad, None, None, None, None, None
 
 
 class BasicNorm(torch.nn.Module):
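The new constraint reuses the existing nudge machinery: (below_threshold - above_threshold) evaluates to +1 on channels whose mean absolute value is under min_abs, -1 on channels over max_abs, and 0 otherwise, so the same max_factor-sized gradient component either grows or shrinks the activations. A standalone sketch of the max_abs case, replaying the backward() arithmetic by hand (illustrative values, nothing imported from the repo):

    import torch

    # Replicate the backward() arithmetic for a channel whose mean |x|
    # exceeds max_abs, to see which way the gradient is nudged.
    max_factor = 0.01
    x = 2000.0 * torch.randn(4, 3)      # mean |x| well above max_abs = 1000
    xgt0 = x > 0
    below_threshold = torch.tensor(0.)  # mean_abs < min_abs: false
    above_threshold = torch.tensor(1.)  # mean_abs > max_abs: true

    # (below - above) = -1 here; it would be +1 in the min_abs case.
    scale_factor = ((below_threshold - above_threshold) *
                    (xgt0.float() - 0.5) * (max_factor * 2.0))

    x_grad = torch.ones_like(x)         # stand-in for the upstream gradient
    neg_delta_grad = x_grad.abs() * scale_factor
    new_grad = x_grad - neg_delta_grad  # 1.01 where x > 0, 0.99 where x <= 0
    # Under gradient descent the extra +/-0.01 component always points
    # toward zero, shrinking the channel's mean |x| over time.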
@@ -521,20 +526,33 @@ class DerivBalancer(torch.nn.Module):
            than zero is that it will tend to prevent situations where the
            inputs shrink close to zero and the nonlinearity (e.g. swish)
            behaves like a linear function and we learn nothing.
+       min_abs: the minimum average-absolute-value per channel, which
+           we allow, before we start to modify the derivatives to prevent
+           this.  This is to prevent a failure mode where the activations
+           become so small that the nonlinearity effectively becomes linear,
+           which makes the module useless and it gets even smaller
+           to try to "turn it off" completely.
+       max_abs: the maximum average-absolute-value per channel, which
+           we allow, before we start to modify the derivatives to prevent
+           this.  This is to prevent the possibility of activations getting
+           out of floating point numerical range (especially in half precision).
     """
     def __init__(self, channel_dim: int,
                  threshold: float = 0.05,
                  max_factor: float = 0.01,
-                 min_abs: float = 0.2):
+                 min_abs: float = 0.2,
+                 max_abs: float = 1000.0):
         super(DerivBalancer, self).__init__()
         self.channel_dim = channel_dim
         self.threshold = threshold
         self.max_factor = max_factor
         self.min_abs = min_abs
+        self.max_abs = max_abs
 
     def forward(self, x: Tensor) -> Tensor:
         return DerivBalancerFunction.apply(x, self.channel_dim, self.threshold,
-                                           self.max_factor, self.min_abs)
+                                           self.max_factor, self.min_abs,
+                                           self.max_abs)
 
 
 class DoubleSwish(torch.nn.Module):
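For context, this is how the module wrapper is meant to be dropped into a network; a hypothetical standalone usage sketch (shapes and values made up), assuming the DerivBalancer class above is importable:

    import torch

    # The balancer is an identity in the forward direction; all of its
    # effect is in how it rewrites gradients during backprop.
    balancer = DerivBalancer(channel_dim=-1, min_abs=0.2, max_abs=1000.0)

    x = torch.randn(8, 100, 256, requires_grad=True)
    y = balancer(x)
    assert torch.equal(y, x)    # forward pass leaves values untouched

    y.pow(2).sum().backward()   # the nudges show up only in x.grad
    print(x.grad.shape)         # torch.Size([8, 100, 256])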

@@ -110,7 +110,7 @@ def get_parser():
     parser.add_argument(
         "--exp-dir",
         type=str,
-        default="transducer_stateless/randcombine1_expscale3_rework2c",
+        default="transducer_stateless/randcombine1_expscale3_rework2c_maxabs1000",
         help="""The experiment dir.
         It specifies the directory where all training related
         files, e.g., checkpoints, log, etc, are saved