mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 10:02:22 +00:00
Add max-abs-value constraint in DerivBalancer
This commit is contained in:
parent
6042c96db2
commit
e6a501d3c8
@ -325,9 +325,11 @@ class DerivBalancerFunction(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, x: Tensor,
|
||||
channel_dim: int,
|
||||
threshold: float = 0.05,
|
||||
max_factor: float = 0.05,
|
||||
min_abs: float = 0.5) -> Tensor:
|
||||
threshold: float, # e.g. 0.05
|
||||
max_factor: float, # e.g. 0.01
|
||||
min_abs: float, # e.g. 0.2
|
||||
max_abs: float, # e.g. 1000.0
|
||||
) -> Tensor:
|
||||
if x.requires_grad:
|
||||
if channel_dim < 0:
|
||||
channel_dim += x.ndim
|
||||
@ -336,23 +338,26 @@ class DerivBalancerFunction(torch.autograd.Function):
|
||||
proportion_positive = torch.mean(xgt0.to(x.dtype), dim=sum_dims, keepdim=True)
|
||||
factor = (threshold - proportion_positive).relu() * (max_factor / threshold)
|
||||
|
||||
below_threshold = (torch.mean(x.abs(), dim=sum_dims, keepdim=True) < min_abs)
|
||||
mean_abs = torch.mean(x.abs(), dim=sum_dims, keepdim=True)
|
||||
below_threshold = (mean_abs < min_abs)
|
||||
above_threshold = (mean_abs > max_abs)
|
||||
|
||||
ctx.save_for_backward(factor, xgt0, below_threshold)
|
||||
ctx.save_for_backward(factor, xgt0, below_threshold, above_threshold)
|
||||
ctx.max_factor = max_factor
|
||||
ctx.sum_dims = sum_dims
|
||||
return x
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None, None]:
|
||||
factor, xgt0, below_threshold = ctx.saved_tensors
|
||||
def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None, None, None]:
|
||||
factor, xgt0, below_threshold, above_threshold = ctx.saved_tensors
|
||||
dtype = x_grad.dtype
|
||||
too_small_factor = below_threshold.to(dtype) * (xgt0.to(dtype) - 0.5) * (ctx.max_factor * 2.0)
|
||||
scale_factor = ((below_threshold.to(dtype) - above_threshold.to(dtype)) *
|
||||
(xgt0.to(dtype) - 0.5) * (ctx.max_factor * 2.0))
|
||||
|
||||
neg_delta_grad = x_grad.abs() * (factor + too_small_factor)
|
||||
neg_delta_grad = x_grad.abs() * (factor + scale_factor)
|
||||
|
||||
|
||||
return x_grad - neg_delta_grad, None, None, None, None
|
||||
return x_grad - neg_delta_grad, None, None, None, None, None
|
||||
|
||||
|
||||
class BasicNorm(torch.nn.Module):
|
||||
@ -521,20 +526,33 @@ class DerivBalancer(torch.nn.Module):
|
||||
than zero is that it will tend to prevent situations where the
|
||||
inputs shrink close to zero and the nonlinearity (e.g. swish)
|
||||
behaves like a linear function and we learn nothing.
|
||||
min_abs: the minimum average-absolute-value per channel, which
|
||||
we allow, before we start to modify the derivatives to prevent
|
||||
this. This is to prevent a failure mode where the activations
|
||||
become so small that the nonlinearity effectively becomes linear,
|
||||
which makes the module useless and it gets even smaller
|
||||
to try to "turn it off" completely.
|
||||
max_abs: the maximum average-absolute-value per channel, which
|
||||
we allow, before we start to modify the derivatives to prevent
|
||||
this. This is to prevent the possibility of activations getting
|
||||
out of floating point numerical range (especially in half precision).
|
||||
"""
|
||||
def __init__(self, channel_dim: int,
|
||||
threshold: float = 0.05,
|
||||
max_factor: float = 0.01,
|
||||
min_abs: float = 0.2):
|
||||
min_abs: float = 0.2,
|
||||
max_abs: float = 1000.0):
|
||||
super(DerivBalancer, self).__init__()
|
||||
self.channel_dim = channel_dim
|
||||
self.threshold = threshold
|
||||
self.max_factor = max_factor
|
||||
self.min_abs = min_abs
|
||||
self.max_abs = max_abs
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
return DerivBalancerFunction.apply(x, self.channel_dim, self.threshold,
|
||||
self.max_factor, self.min_abs)
|
||||
self.max_factor, self.min_abs,
|
||||
self.max_abs)
|
||||
|
||||
|
||||
class DoubleSwish(torch.nn.Module):
|
||||
|
@ -110,7 +110,7 @@ def get_parser():
|
||||
parser.add_argument(
|
||||
"--exp-dir",
|
||||
type=str,
|
||||
default="transducer_stateless/randcombine1_expscale3_rework2c",
|
||||
default="transducer_stateless/randcombine1_expscale3_rework2c_maxabs1000",
|
||||
help="""The experiment dir.
|
||||
It specifies the directory where all training related
|
||||
files, e.g., checkpoints, log, etc, are saved
|
||||
|
Loading…
x
Reference in New Issue
Block a user