mirror of https://github.com/k2-fsa/icefall.git
Modify scaling.py to prevent constant values
commit a55f8c9c14
parent 633cbd551a
@@ -58,9 +58,10 @@ class ActivationBalancerFunction(torch.autograd.Function):
         if channel_dim < 0:
             channel_dim += x.ndim
         sum_dims = [d for d in range(x.ndim) if d != channel_dim]
-        xgt0 = x > 0
+        x_normalized = x - torch.mean(x, dim=sum_dims, keepdim=True)
+        xgtmean = (x_normalized > 0)
         proportion_positive = torch.mean(
-            xgt0.to(x.dtype), dim=sum_dims, keepdim=True
+            (x > 0).to(x.dtype), dim=sum_dims, keepdim=True
         )
         factor1 = (
             (min_positive - proportion_positive).relu()
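Note (not part of the commit): a minimal sketch of what the new statistics in this hunk compute, on a made-up 2-D input with channel_dim=1; it only re-runs the lines shown above.

import torch

# Hypothetical toy input: 4 frames, 3 channels (channel_dim = 1).
x = torch.tensor([[0.5, 2.0, -1.0],
                  [1.5, 2.0, -3.0],
                  [0.5, 2.0, -1.0],
                  [1.5, 2.0, -2.0]])
channel_dim = 1
sum_dims = [d for d in range(x.ndim) if d != channel_dim]

# New statistic: sign relative to the per-channel mean, not relative to zero.
x_normalized = x - torch.mean(x, dim=sum_dims, keepdim=True)
xgtmean = x_normalized > 0

# The sign-proportion constraint still uses (x > 0), as before.
proportion_positive = torch.mean((x > 0).to(x.dtype), dim=sum_dims, keepdim=True)

# Channel 1 is constant (always 2.0): xgtmean is all False there even though
# proportion_positive is 1.0, which is what the rest of the change exploits.
print(xgtmean)
print(proportion_positive)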
@@ -74,16 +75,24 @@ class ActivationBalancerFunction(torch.autograd.Function):
             if max_positive != 1.0
             else 0.0
         )
+        # `factor` is a tensor of shape something like (1, 1, num_channels,
+        # 1), containing elements between -1 and 1 that are zero if the
+        # proportion of positive features is between min_positive and
+        # max_positive, max_factor if proportion==0.0 (all features are negative),
+        # and -max_factor if proportion==1.0 (all features are positive). It is
+        # an amount per channel by which we'll modify the gradient; the sign
+        # of modifying the gradient will depend on the sign of the gradient.
+
         factor = factor1 + factor2
         if isinstance(factor, float):
             factor = torch.zeros_like(proportion_positive)
 
-        mean_abs = torch.mean(x.abs(), dim=sum_dims, keepdim=True)
+        mean_abs = torch.mean(x_normalized.abs(), dim=sum_dims, keepdim=True)
         below_threshold = mean_abs < min_abs
         above_threshold = mean_abs > max_abs
 
         ctx.save_for_backward(
-            factor, xgt0, below_threshold, above_threshold
+            factor, xgtmean, below_threshold, above_threshold
         )
         ctx.max_factor = max_factor
         ctx.sum_dims = sum_dims
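Note (not part of the commit): a small numeric check of the claim in the new comment block. The factor1/factor2 expressions below are assumed from the surrounding, unchanged code, and the hyperparameter values are illustrative only.

import torch

# Assumed hyperparameters, in the spirit of the defaults used elsewhere.
min_positive, max_positive, max_factor = 0.05, 0.95, 0.01

def factor_for(p: torch.Tensor) -> torch.Tensor:
    # Assumed form of factor1/factor2 from the rest of forward().
    factor1 = (min_positive - p).relu() * (max_factor / min_positive)
    factor2 = (p - max_positive).relu() * (max_factor / (max_positive - 1.0))
    return factor1 + factor2

p = torch.tensor([0.0, 0.5, 1.0])  # proportion of positive features per channel
print(factor_for(p))
# tensor([ 0.0100, 0.0000, -0.0100]): max_factor when all features are negative,
# 0 in the allowed band, -max_factor when all features are positive.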
@@ -93,11 +102,11 @@ class ActivationBalancerFunction(torch.autograd.Function):
     def backward(
         ctx, x_grad: Tensor
     ) -> Tuple[Tensor, None, None, None, None, None, None]:
-        factor, xgt0, below_threshold, above_threshold = ctx.saved_tensors
+        factor, xgtmean, below_threshold, above_threshold = ctx.saved_tensors
         dtype = x_grad.dtype
         scale_factor = (
             (below_threshold.to(dtype) - above_threshold.to(dtype))
-            * (xgt0.to(dtype) - 0.5)
+            * (xgtmean.to(dtype) - 0.5)
             * (ctx.max_factor * 2.0)
         )
 
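Note (not part of the commit): a rough sketch of how the modified backward() nudges gradients, using made-up tensors and assuming the definition neg_delta_grad = x_grad.abs() * (factor + scale_factor) that precedes the visible return statement.

import torch

max_factor = 0.01                              # example value
xgtmean = torch.tensor([[True], [False]])      # element above / below its channel mean
below_threshold = torch.tensor([[True]])       # channel's mean-abs deviation is too small
above_threshold = torch.tensor([[False]])
factor = torch.tensor([[0.0]])                 # sign constraint satisfied here

dtype = torch.float32
scale_factor = (
    (below_threshold.to(dtype) - above_threshold.to(dtype))
    * (xgtmean.to(dtype) - 0.5)
    * (max_factor * 2.0)
)
x_grad = torch.ones(2, 1)
neg_delta_grad = x_grad.abs() * (factor + scale_factor)  # assumed definition
print(x_grad - neg_delta_grad)
# [[0.99], [1.01]]: elements above the channel mean get slightly smaller
# gradients, elements below it slightly larger ones, which is intended to
# increase the channel's spread around its mean.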
@@ -105,6 +114,8 @@ class ActivationBalancerFunction(torch.autograd.Function):
         return x_grad - neg_delta_grad, None, None, None, None, None, None
 
 
+
+
 class BasicNorm(torch.nn.Module):
     """
     This is intended to be a simpler, and hopefully cheaper, replacement for
@@ -404,12 +415,12 @@ class ActivationBalancer(torch.nn.Module):
            either the sign constraint or the magnitude constraint;
            e.g. with max_factor=0.02, the the derivatives would be multiplied by
            values in the range [0.98..1.02].
-           min_abs: the minimum average-absolute-value per channel, which
-             we allow, before we start to modify the derivatives to prevent
-             this.
-           max_abs: the maximum average-absolute-value per channel, which
-             we allow, before we start to modify the derivatives to prevent
-             this.
+           min_abs: the minimum average-absolute-value difference from the mean
+             value per channel, which we allow, before we start to modify
+             the derivatives to prevent this.
+           max_abs: the maximum average-absolute-value difference from the mean
+             value per channel, which we allow, before we start to modify
+             the derivatives to prevent this.
     """
 
     def __init__(
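Note (not part of the commit): a sketch of why measuring mean_abs on x_normalized "prevents constant values", as the commit title says and the updated docstring describes. The threshold value is an illustrative choice.

import torch

# A channel that is stuck at a constant, non-zero value.
x = torch.full((4, 1), 3.7)
min_abs = 0.2                                   # example threshold
sum_dims = [0]

old_mean_abs = torch.mean(x.abs(), dim=sum_dims, keepdim=True)             # 3.7
x_normalized = x - torch.mean(x, dim=sum_dims, keepdim=True)
new_mean_abs = torch.mean(x_normalized.abs(), dim=sum_dims, keepdim=True)  # 0.0

print(old_mean_abs < min_abs)  # False: the old statistic never flags this channel
print(new_mean_abs < min_abs)  # True: below_threshold now fires and backward() reacts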