mirror of https://github.com/k2-fsa/icefall.git (synced 2025-12-11 06:55:27 +00:00)

commit 2ef0228db0 (parent 912adfff7c)

    Make the ActivationBalancer relative to the mean, limited to -min_abs..max_abs
@@ -105,17 +105,18 @@ class ActivationBalancerFunction(torch.autograd.Function):
         ctx,
         x: Tensor,
         scale_factor: Tensor,
+        mean: Tensor,
         sign_factor: Optional[Tensor],
         channel_dim: int,
     ) -> Tensor:
         if channel_dim < 0:
             channel_dim += x.ndim
         ctx.channel_dim = channel_dim
-        xgt0 = (x > 0)
+        xgtmean = (x > mean)
         if sign_factor is None:
-            ctx.save_for_backward(xgt0, scale_factor)
+            ctx.save_for_backward(xgtmean, scale_factor)
         else:
-            ctx.save_for_backward(xgt0, scale_factor, sign_factor)
+            ctx.save_for_backward(xgtmean, scale_factor, sign_factor)
         return x
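The first hunk changes what the forward pass saves: the per-element mask is now computed against a per-channel mean (the clamped mean passed in via the new mean argument) rather than against zero. A minimal sketch of the difference, using made-up tensors (the names xgt0, xgtmean and mean mirror the diff; nothing else here is icefall code):

import torch

# toy input with shape (batch, num_channels); the channel dim is the last dim here
x = torch.tensor([[-1.0, 0.2, 0.3],
                  [ 0.1, 0.4, 0.9]])
# hypothetical clamped per-channel mean, broadcastable against x
mean = torch.tensor([[0.0, 0.3, 0.5]])

xgt0 = x > 0        # old behaviour: "is the activation positive?"
xgtmean = x > mean  # new behaviour: "is the activation above its (clamped) channel mean?"
print(xgt0)
print(xgtmean)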
@@ -124,29 +125,48 @@ class ActivationBalancerFunction(torch.autograd.Function):
         ctx, x_grad: Tensor
     ) -> Tuple[Tensor, None, None, None]:
         if len(ctx.saved_tensors) == 3:
-            xgt0, scale_factor, sign_factor = ctx.saved_tensors
+            xgtmean, scale_factor, sign_factor = ctx.saved_tensors
             for _ in range(ctx.channel_dim, x_grad.ndim - 1):
                 scale_factor = scale_factor.unsqueeze(-1)
                 sign_factor = sign_factor.unsqueeze(-1)
-            factor = sign_factor + scale_factor * (xgt0.to(x_grad.dtype) - 0.5)
+            factor = sign_factor + scale_factor * (xgtmean.to(x_grad.dtype) - 0.5)
         else:
-            xgt0, scale_factor = ctx.saved_tensors
+            xgtmean, scale_factor = ctx.saved_tensors
             for _ in range(ctx.channel_dim, x_grad.ndim - 1):
                 scale_factor = scale_factor.unsqueeze(-1)
-            factor = scale_factor * (xgt0.to(x_grad.dtype) - 0.5)
+            factor = scale_factor * (xgtmean.to(x_grad.dtype) - 0.5)
         neg_delta_grad = x_grad.abs() * factor
-        return x_grad - neg_delta_grad, None, None, None,
+        return x_grad - neg_delta_grad, None, None, None, None
 
 
 def _compute_scale_factor(x: Tensor,
                           channel_dim: int,
                           min_abs: float,
                           max_abs: float,
                           gain_factor: float,
-                          max_factor: float) -> Tensor:
+                          max_factor: float) -> Tuple[Tensor, Tensor]:
+    """
+    Computes a factor used in ActivationBalancer, that dictates how much we penalize
+    (or anti-penalize) the scale on the features.
+
+    Returns: (scale_factor, mean)
+      scale_factor: can be positive or negative, between -max_factor and max_factor;
+        dictates penalty or anti-penalty.  It is of shape (num_channels,).
+      mean: mean per channel that we use for purposes of scale_factor; it is clamped
+        to -min_abs..min_abs.  Its shape is like (1, num_channels, 1, 1), depending on
+        the shape of x and the channel-dim.
+    """
     if channel_dim < 0:
         channel_dim += x.ndim
     sum_dims = [d for d in range(x.ndim) if d != channel_dim]
-    x_abs_mean = torch.mean(x.abs(), dim=sum_dims).to(torch.float32)
+    x_mean = torch.mean(x, dim=sum_dims, keepdim=True).to(torch.float32)
+    # the idea is that for purposes of applying max_abs, we regress effectively
+    # toward zero (assuming min_abs is much less than max_abs).
+    x_mean = x_mean.clamp(min=-min_abs, max=min_abs)
+    x_abs_mean = torch.mean((x - x_mean).abs(), dim=sum_dims).to(torch.float32)
+
     if min_abs == 0.0:
         below_threshold = 0.0
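In the backward pass the saved mask only changes name and meaning; the arithmetic is unchanged, and the extra trailing None matches the new mean input to forward. A hedged sketch of what that factor does to the incoming gradient, with the channel dim assumed to be the last one and arbitrary example numbers (not taken from icefall):

import torch

x_grad = torch.tensor([[ 0.5, -0.2],
                       [-1.0,  0.3]])
xgtmean = torch.tensor([[True, False],
                        [True, True]])
scale_factor = torch.tensor([0.04, -0.02])  # per-channel, as returned by _compute_scale_factor

# elements above the mean get +scale/2, elements below get -scale/2
factor = scale_factor * (xgtmean.to(x_grad.dtype) - 0.5)
neg_delta_grad = x_grad.abs() * factor
adjusted_grad = x_grad - neg_delta_grad
print(adjusted_grad)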
@@ -157,7 +177,7 @@ def _compute_scale_factor(x: Tensor,
 
     above_threshold = ((x_abs_mean - max_abs) * (gain_factor / max_abs)).clamp(min=0, max=max_factor)
 
-    return below_threshold - above_threshold
+    return below_threshold - above_threshold, x_mean
 
 def _compute_sign_factor(x: Tensor,
                          channel_dim: int,
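Putting the two _compute_scale_factor hunks together, here is a self-contained sketch of the new logic. The below_threshold branch for min_abs > 0 is not shown in this diff, so its formula below (mirroring the above_threshold one) is a guess, not the icefall source; the function name is also made up to avoid confusion with the real one.

from typing import Tuple
import torch
from torch import Tensor

def compute_scale_factor_sketch(x: Tensor,
                                channel_dim: int,
                                min_abs: float,
                                max_abs: float,
                                gain_factor: float,
                                max_factor: float) -> Tuple[Tensor, Tensor]:
    if channel_dim < 0:
        channel_dim += x.ndim
    sum_dims = [d for d in range(x.ndim) if d != channel_dim]
    # per-channel mean, clamped so that for max_abs purposes we regress toward zero
    x_mean = torch.mean(x, dim=sum_dims, keepdim=True).to(torch.float32)
    x_mean = x_mean.clamp(min=-min_abs, max=min_abs)
    # mean absolute deviation from the clamped mean, per channel
    x_abs_mean = torch.mean((x - x_mean).abs(), dim=sum_dims).to(torch.float32)
    if min_abs == 0.0:
        below_threshold = 0.0
    else:
        # assumed form of the unchanged branch (not shown in the diff above)
        below_threshold = ((min_abs - x_abs_mean) * (gain_factor / min_abs)).clamp(min=0, max=max_factor)
    above_threshold = ((x_abs_mean - max_abs) * (gain_factor / max_abs)).clamp(min=0, max=max_factor)
    return below_threshold - above_threshold, x_mean

scale_factor, mean = compute_scale_factor_sketch(
    torch.randn(4, 3, 10), channel_dim=1,
    min_abs=0.2, max_abs=100.0, gain_factor=0.04, max_factor=0.04)
print(scale_factor.shape, mean.shape)   # torch.Size([3]) torch.Size([1, 3, 1])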
@@ -679,13 +699,13 @@ class ActivationBalancer(torch.nn.Module):
                 sign_factor = None
 
 
-            scale_factor = _compute_scale_factor(x, self.channel_dim,
+            scale_factor, mean = _compute_scale_factor(x, self.channel_dim,
                                                  min_abs=float(self.min_abs),
                                                  max_abs=float(self.max_abs),
                                                  gain_factor=float(self.scale_gain_factor) / prob,
                                                  max_factor=float(self.max_factor))
             return ActivationBalancerFunction.apply(
-                x, scale_factor, sign_factor, self.channel_dim,
+                x, scale_factor, mean, sign_factor, self.channel_dim,
             )
         else:
             return _no_op(x)
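The last hunk just threads the new mean output of _compute_scale_factor through to the autograd function. A minimal end-to-end sketch of that flow, using a toy stand-in class (an assumption, not icefall code; the sign_factor branch is omitted for brevity, and the per-channel tensors are arbitrary):

import torch
from torch import Tensor
from typing import Optional, Tuple

class ToyBalancerFunction(torch.autograd.Function):
    """Toy stand-in for ActivationBalancerFunction with the new 5-argument forward."""

    @staticmethod
    def forward(ctx, x: Tensor, scale_factor: Tensor, mean: Tensor,
                sign_factor: Optional[Tensor], channel_dim: int) -> Tensor:
        if channel_dim < 0:
            channel_dim += x.ndim
        ctx.channel_dim = channel_dim
        xgtmean = x > mean                       # mask relative to the clamped mean
        ctx.save_for_backward(xgtmean, scale_factor)
        return x

    @staticmethod
    def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None, None]:
        xgtmean, scale_factor = ctx.saved_tensors
        for _ in range(ctx.channel_dim, x_grad.ndim - 1):
            scale_factor = scale_factor.unsqueeze(-1)
        factor = scale_factor * (xgtmean.to(x_grad.dtype) - 0.5)
        neg_delta_grad = x_grad.abs() * factor
        # one extra trailing None, matching the new `mean` argument of forward
        return x_grad - neg_delta_grad, None, None, None, None

x = torch.randn(4, 3, 10, requires_grad=True)
scale_factor = torch.full((3,), 0.01)   # hypothetical per-channel penalty
mean = torch.zeros(1, 3, 1)             # hypothetical clamped per-channel mean
y = ToyBalancerFunction.apply(x, scale_factor, mean, None, 1)
y.sum().backward()
print(x.grad.shape)                     # torch.Size([4, 3, 10])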