Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-08-09 01:52:41 +00:00)
Add max-abs-value constraint in DerivBalancer
commit e6a501d3c8
parent 6042c96db2
@@ -325,9 +325,11 @@ class DerivBalancerFunction(torch.autograd.Function):
     @staticmethod
     def forward(ctx, x: Tensor,
                 channel_dim: int,
-                threshold: float = 0.05,
-                max_factor: float = 0.05,
-                min_abs: float = 0.5) -> Tensor:
+                threshold: float,  # e.g. 0.05
+                max_factor: float,  # e.g. 0.01
+                min_abs: float,  # e.g. 0.2
+                max_abs: float,  # e.g. 1000.0
+                ) -> Tensor:
         if x.requires_grad:
             if channel_dim < 0:
                 channel_dim += x.ndim
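The defaults disappear from the signature here, surviving only as "e.g." comments, and max_abs is added, so every caller now passes the full set explicitly. Note that torch.autograd.Function.apply() takes positional arguments only, so these hyperparameters could not be passed by keyword in any case. A minimal sketch of the new call, assuming DerivBalancerFunction from this file is in scope (the values are illustrative, taken from the "e.g." comments):

    import torch

    x = torch.randn(4, 100, 256, requires_grad=True)
    y = DerivBalancerFunction.apply(
        x,       # input
        -1,      # channel_dim (here: the last dimension)
        0.05,    # threshold
        0.01,    # max_factor
        0.2,     # min_abs
        1000.0,  # max_abs (new in this commit)
    )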
@@ -336,23 +338,26 @@ class DerivBalancerFunction(torch.autograd.Function):
             proportion_positive = torch.mean(xgt0.to(x.dtype), dim=sum_dims, keepdim=True)
             factor = (threshold - proportion_positive).relu() * (max_factor / threshold)

-            below_threshold = (torch.mean(x.abs(), dim=sum_dims, keepdim=True) < min_abs)
+            mean_abs = torch.mean(x.abs(), dim=sum_dims, keepdim=True)
+            below_threshold = (mean_abs < min_abs)
+            above_threshold = (mean_abs > max_abs)

-            ctx.save_for_backward(factor, xgt0, below_threshold)
+            ctx.save_for_backward(factor, xgt0, below_threshold, above_threshold)
             ctx.max_factor = max_factor
             ctx.sum_dims = sum_dims
         return x

     @staticmethod
-    def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None, None]:
-        factor, xgt0, below_threshold = ctx.saved_tensors
+    def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None, None, None]:
+        factor, xgt0, below_threshold, above_threshold = ctx.saved_tensors
         dtype = x_grad.dtype
-        too_small_factor = below_threshold.to(dtype) * (xgt0.to(dtype) - 0.5) * (ctx.max_factor * 2.0)
+        scale_factor = ((below_threshold.to(dtype) - above_threshold.to(dtype)) *
+                        (xgt0.to(dtype) - 0.5) * (ctx.max_factor * 2.0))

-        neg_delta_grad = x_grad.abs() * (factor + too_small_factor)
+        neg_delta_grad = x_grad.abs() * (factor + scale_factor)


-        return x_grad - neg_delta_grad, None, None, None, None
+        return x_grad - neg_delta_grad, None, None, None, None, None


 class BasicNorm(torch.nn.Module):
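The backward change generalizes the old too_small_factor: (below_threshold - above_threshold) is +1 for channels whose mean absolute value is too small, -1 for channels where it is too large, and 0 otherwise, while (xgt0 - 0.5) flips the sign of the nudge for negative activations. Subtracting neg_delta_grad from x_grad therefore makes a plain SGD step grow |x| in under-range channels and shrink it in over-range ones. A standalone sketch of just this arithmetic, with illustrative values (the sign-balancing factor term is set aside here):

    import torch

    max_factor = 0.01
    x_grad = torch.ones(2)                         # same incoming gradient in both channels
    xgt0 = torch.tensor([True, True])              # both activations positive
    below_threshold = torch.tensor([True, False])  # channel 0: mean |x| < min_abs
    above_threshold = torch.tensor([False, True])  # channel 1: mean |x| > max_abs

    dtype = x_grad.dtype
    scale_factor = ((below_threshold.to(dtype) - above_threshold.to(dtype)) *
                    (xgt0.to(dtype) - 0.5) * (max_factor * 2.0))
    neg_delta_grad = x_grad.abs() * scale_factor   # factor term omitted
    print(x_grad - neg_delta_grad)                 # tensor([0.9900, 1.0100])
    # Channel 0's gradient shrinks, so x -= lr * grad decreases x less and the
    # positive activation ends up larger; channel 1's grows, pulling |x| down.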
@@ -521,20 +526,33 @@ class DerivBalancer(torch.nn.Module):
            than zero is that it will tend to prevent situations where the
            inputs shrink close to zero and the nonlinearity (e.g. swish)
            behaves like a linear function and we learn nothing.
+       min_abs: the minimum average-absolute-value per channel, which
+          we allow, before we start to modify the derivatives to prevent
+          this.  This is to prevent a failure mode where the activations
+          become so small that the nonlinearity effectively becomes linear,
+          which makes the module useless and it gets even smaller
+          to try to "turn it off" completely.
+       max_abs: the maximum average-absolute-value per channel, which
+          we allow, before we start to modify the derivatives to prevent
+          this.  This is to prevent the possibility of activations getting
+          out of floating point numerical range (especially in half precision).
     """
     def __init__(self, channel_dim: int,
                  threshold: float = 0.05,
                  max_factor: float = 0.01,
-                 min_abs: float = 0.2):
+                 min_abs: float = 0.2,
+                 max_abs: float = 1000.0):
         super(DerivBalancer, self).__init__()
         self.channel_dim = channel_dim
         self.threshold = threshold
         self.max_factor = max_factor
         self.min_abs = min_abs
+        self.max_abs = max_abs

     def forward(self, x: Tensor) -> Tensor:
         return DerivBalancerFunction.apply(x, self.channel_dim, self.threshold,
-                                           self.max_factor, self.min_abs)
+                                           self.max_factor, self.min_abs,
+                                           self.max_abs)


 class DoubleSwish(torch.nn.Module):
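At the module level the new knob is plumbed through with a default of 1000.0, so existing constructor calls keep working unchanged. A hedged usage sketch (DerivBalancer is assumed to be in scope; the file it lives in is not named in this diff):

    import torch

    balancer = DerivBalancer(channel_dim=-1, min_abs=0.2, max_abs=1000.0)
    # Scale chosen so mean |x| (~1600) trips the new max_abs constraint.
    x = 2000.0 * torch.randn(8, 100, 256)
    x.requires_grad_(True)
    y = balancer(x)
    assert torch.allclose(y, x)  # the forward pass is the identity
    y.sum().backward()           # only the gradients are adjusted, pushing |x| down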
@@ -110,7 +110,7 @@ def get_parser():
     parser.add_argument(
         "--exp-dir",
         type=str,
-        default="transducer_stateless/randcombine1_expscale3_rework2c",
+        default="transducer_stateless/randcombine1_expscale3_rework2c_maxabs1000",
         help="""The experiment dir.
         It specifies the directory where all training related
         files, e.g., checkpoints, log, etc, are saved