Add max-abs-value constraint in DerivBalancer

Daniel Povey 2022-03-13 11:52:13 +08:00
parent 6042c96db2
commit e6a501d3c8
2 changed files with 31 additions and 13 deletions


@@ -325,9 +325,11 @@ class DerivBalancerFunction(torch.autograd.Function):
     @staticmethod
     def forward(ctx, x: Tensor,
                 channel_dim: int,
-                threshold: float = 0.05,
-                max_factor: float = 0.05,
-                min_abs: float = 0.5) -> Tensor:
+                threshold: float,  # e.g. 0.05
+                max_factor: float,  # e.g. 0.01
+                min_abs: float,  # e.g. 0.2
+                max_abs: float,  # e.g. 1000.0
+                ) -> Tensor:
         if x.requires_grad:
             if channel_dim < 0:
                 channel_dim += x.ndim
@@ -336,23 +338,26 @@ class DerivBalancerFunction(torch.autograd.Function):
             proportion_positive = torch.mean(xgt0.to(x.dtype), dim=sum_dims, keepdim=True)
             factor = (threshold - proportion_positive).relu() * (max_factor / threshold)
-            below_threshold = (torch.mean(x.abs(), dim=sum_dims, keepdim=True) < min_abs)
-            ctx.save_for_backward(factor, xgt0, below_threshold)
+            mean_abs = torch.mean(x.abs(), dim=sum_dims, keepdim=True)
+            below_threshold = (mean_abs < min_abs)
+            above_threshold = (mean_abs > max_abs)
+            ctx.save_for_backward(factor, xgt0, below_threshold, above_threshold)
             ctx.max_factor = max_factor
             ctx.sum_dims = sum_dims
         return x
 
     @staticmethod
-    def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None, None]:
-        factor, xgt0, below_threshold = ctx.saved_tensors
+    def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None, None, None]:
+        factor, xgt0, below_threshold, above_threshold = ctx.saved_tensors
         dtype = x_grad.dtype
-        too_small_factor = below_threshold.to(dtype) * (xgt0.to(dtype) - 0.5) * (ctx.max_factor * 2.0)
-        neg_delta_grad = x_grad.abs() * (factor + too_small_factor)
-        return x_grad - neg_delta_grad, None, None, None, None
+        scale_factor = ((below_threshold.to(dtype) - above_threshold.to(dtype)) *
+                        (xgt0.to(dtype) - 0.5) * (ctx.max_factor * 2.0))
+        neg_delta_grad = x_grad.abs() * (factor + scale_factor)
+        return x_grad - neg_delta_grad, None, None, None, None, None
 
 
 class BasicNorm(torch.nn.Module):
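
The sign logic in the new backward() is easiest to see on a toy tensor. Below is a standalone sketch (not the repository's code) that re-runs the scale_factor computation with channel_dim=1, so statistics are averaged over dim 0; the proportion_positive-based factor term is omitted since it is zero whenever enough values are positive:

import torch

# Channel 0 is far below min_abs; channel 1 is far above max_abs.
max_factor, min_abs, max_abs = 0.01, 0.2, 1000.0
x = torch.tensor([[0.05, 2000.0],
                  [-0.03, -1500.0]])          # shape (batch, channel)
xgt0 = x > 0
mean_abs = x.abs().mean(dim=0, keepdim=True)  # per-channel mean |x|
below = (mean_abs < min_abs).float()          # [[1., 0.]]
above = (mean_abs > max_abs).float()          # [[0., 1.]]
# (below - above) flips the sign of the nudge for over-large channels:
scale_factor = (below - above) * (xgt0.float() - 0.5) * (max_factor * 2.0)
x_grad = torch.ones_like(x)                   # pretend upstream gradient
print(x_grad - x_grad.abs() * scale_factor)
# tensor([[0.9900, 1.0100],
#         [1.0100, 0.9900]])
# Channel 0 (too small): the gradient bias makes SGD increase |x|;
# channel 1 (too large): the opposite bias shrinks |x|.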
@@ -521,20 +526,33 @@ class DerivBalancer(torch.nn.Module):
            than zero is that it will tend to prevent situations where the
            inputs shrink close to zero and the nonlinearity (e.g. swish)
            behaves like a linear function and we learn nothing.
+      min_abs: the minimum average-absolute-value per channel, which
+           we allow, before we start to modify the derivatives to prevent
+           this.  This is to prevent a failure mode where the activations
+           become so small that the nonlinearity effectively becomes linear,
+           which makes the module useless and it gets even smaller
+           to try to "turn it off" completely.
+      max_abs: the maximum average-absolute-value per channel, which
+           we allow, before we start to modify the derivatives to prevent
+           this.  This is to prevent the possibility of activations getting
+           out of floating point numerical range (especially in half precision).
     """
     def __init__(self, channel_dim: int,
                  threshold: float = 0.05,
                  max_factor: float = 0.01,
-                 min_abs: float = 0.2):
+                 min_abs: float = 0.2,
+                 max_abs: float = 1000.0):
         super(DerivBalancer, self).__init__()
         self.channel_dim = channel_dim
         self.threshold = threshold
         self.max_factor = max_factor
         self.min_abs = min_abs
+        self.max_abs = max_abs
 
     def forward(self, x: Tensor) -> Tensor:
         return DerivBalancerFunction.apply(x, self.channel_dim, self.threshold,
-                                           self.max_factor, self.min_abs)
+                                           self.max_factor, self.min_abs,
+                                           self.max_abs)
 
 
 class DoubleSwish(torch.nn.Module):
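
For context, a minimal usage sketch of the updated module (assuming the DerivBalancer defined above is in scope; the tensor shape is illustrative):

import torch

balancer = DerivBalancer(channel_dim=-1, max_abs=1000.0)
x = (2000.0 * torch.randn(4, 10, 256)).requires_grad_(True)  # mean |x| >> max_abs
y = balancer(x)      # the forward pass is the identity
y.sum().backward()   # the backward pass biases x.grad so training shrinks |x|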


@@ -110,7 +110,7 @@ def get_parser():
     parser.add_argument(
         "--exp-dir",
         type=str,
-        default="transducer_stateless/randcombine1_expscale3_rework2c",
+        default="transducer_stateless/randcombine1_expscale3_rework2c_maxabs1000",
         help="""The experiment dir.
         It specifies the directory where all training related
         files, e.g., checkpoints, log, etc, are saved
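
A quick sanity check of the new default; note that the module name "train" is an assumption, since the diff shows only get_parser() and not the file name:

from train import get_parser  # hypothetical module name

parser = get_parser()
assert parser.get_default("exp_dir") == (
    "transducer_stateless/randcombine1_expscale3_rework2c_maxabs1000")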