Modify scaling.py to prevent constant values

parent 3c1fddaf48
commit 9d7af4be20
@@ -58,9 +58,10 @@ class ActivationBalancerFunction(torch.autograd.Function):
         if channel_dim < 0:
             channel_dim += x.ndim
         sum_dims = [d for d in range(x.ndim) if d != channel_dim]
-        xgt0 = x > 0
+        x_normalized = x - torch.mean(x, dim=sum_dims, keepdim=True)
+        xgtmean = (x_normalized > 0)
         proportion_positive = torch.mean(
-            xgt0.to(x.dtype), dim=sum_dims, keepdim=True
+            (x > 0).to(x.dtype), dim=sum_dims, keepdim=True
         )
         factor1 = (
             (min_positive - proportion_positive).relu()
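What this first hunk changes: the sign tensor saved for the backward pass is now taken relative to each channel's mean (`xgtmean`) rather than to zero, and (in the next hunk) `mean_abs` is computed on the mean-subtracted activations. A channel frozen at a constant positive value previously looked healthy to the balancer (large `mean_abs`, all-positive signs); after the change its mean-subtracted magnitude is near zero, so the `min_abs` constraint can fire. A minimal standalone sketch of the two statistics (the toy tensor and shapes are illustrative, not icefall's):

import torch

# Toy activations: 4 frames x 2 channels; channel 1 is stuck near 5.0.
x = torch.tensor([[ 1.0, 5.1],
                  [-1.0, 4.9],
                  [ 2.0, 5.1],
                  [-2.0, 4.9]])
sum_dims = [0]  # every dim except the channel dim

# Old statistics: sign and magnitude relative to zero.
print((x > 0).float().mean(dim=sum_dims))     # tensor([0.5000, 1.0000])
print(x.abs().mean(dim=sum_dims))             # tensor([1.5000, 5.0000]) -- looks healthy

# New statistics: relative to the per-channel mean.
x_normalized = x - torch.mean(x, dim=sum_dims, keepdim=True)
xgtmean = x_normalized > 0
print(xgtmean.float().mean(dim=sum_dims))     # tensor([0.5000, 0.5000])
print(x_normalized.abs().mean(dim=sum_dims))  # tensor([1.5000, 0.1000]) -- flags the constant channel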
@@ -74,16 +75,24 @@ class ActivationBalancerFunction(torch.autograd.Function):
             if max_positive != 1.0
             else 0.0
         )
+        # `factor` is a tensor of shape something like (1, 1, num_channels,
+        # 1), containing elements between -1 and 1 that are zero if the
+        # proportion of positive features is between min_positive and
+        # max_positive, max_factor if proportion==0.0 (all features are negative),
+        # and -max_factor if proportion==1.0 (all features are positive). It is
+        # an amount per channel by which we'll modify the gradient; the sign
+        # of modifying the gradient will depend on the sign of the gradient.
+
         factor = factor1 + factor2
         if isinstance(factor, float):
             factor = torch.zeros_like(proportion_positive)
 
-        mean_abs = torch.mean(x.abs(), dim=sum_dims, keepdim=True)
+        mean_abs = torch.mean(x_normalized.abs(), dim=sum_dims, keepdim=True)
         below_threshold = mean_abs < min_abs
         above_threshold = mean_abs > max_abs
 
         ctx.save_for_backward(
-            factor, xgt0, below_threshold, above_threshold
+            factor, xgtmean, below_threshold, above_threshold
         )
         ctx.max_factor = max_factor
         ctx.sum_dims = sum_dims
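The comment added above pins down the semantics of `factor`; the code computing `factor1` and `factor2` is only partially visible in this hunk. Below is a hedged reconstruction of that mapping, consistent with the comment. The normalizing denominators and the function name are my assumptions, since the full expressions lie outside the hunk:

import torch

def balancer_factor(proportion_positive: torch.Tensor,
                    min_positive: float = 0.05,
                    max_positive: float = 0.95,
                    max_factor: float = 0.01) -> torch.Tensor:
    # Zero when proportion_positive lies in [min_positive, max_positive],
    # ramping linearly to +max_factor at 0.0 and to -max_factor at 1.0.
    factor1 = (min_positive - proportion_positive).relu() * (
        max_factor / min_positive if min_positive != 0.0 else 0.0
    )
    factor2 = -(proportion_positive - max_positive).relu() * (
        max_factor / (1.0 - max_positive) if max_positive != 1.0 else 0.0
    )
    return factor1 + factor2

p = torch.tensor([0.0, 0.05, 0.5, 0.95, 1.0])
print(balancer_factor(p))  # ~ [0.01, 0.0, 0.0, 0.0, -0.01]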
@@ -93,11 +102,11 @@ class ActivationBalancerFunction(torch.autograd.Function):
     def backward(
         ctx, x_grad: Tensor
     ) -> Tuple[Tensor, None, None, None, None, None, None]:
-        factor, xgt0, below_threshold, above_threshold = ctx.saved_tensors
+        factor, xgtmean, below_threshold, above_threshold = ctx.saved_tensors
         dtype = x_grad.dtype
         scale_factor = (
             (below_threshold.to(dtype) - above_threshold.to(dtype))
-            * (xgt0.to(dtype) - 0.5)
+            * (xgtmean.to(dtype) - 0.5)
             * (ctx.max_factor * 2.0)
         )
 
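The backward pass now keys the direction of the gradient nudge off `xgtmean` instead of `xgt0`. A standalone illustration of the `scale_factor` expression from this hunk (shapes and values are illustrative; the `neg_delta_grad` step that consumes it sits outside the hunk):

import torch

max_factor = 0.01
dtype = torch.float32

# One channel flagged as too constant (its mean-abs fell below min_abs).
below_threshold = torch.tensor([[True]])
above_threshold = torch.tensor([[False]])
# Two activations in that channel: one above its mean, one below.
xgtmean = torch.tensor([[True], [False]])

scale_factor = (
    (below_threshold.to(dtype) - above_threshold.to(dtype))
    * (xgtmean.to(dtype) - 0.5)
    * (max_factor * 2.0)
)
print(scale_factor)
# tensor([[ 0.0100],
#         [-0.0100]])
# +max_factor where the activation is above its channel mean, -max_factor
# below it: the resulting gradient bias pushes the values apart, so the
# channel cannot settle at a constant -- the point of this commit.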
@@ -105,6 +114,8 @@ class ActivationBalancerFunction(torch.autograd.Function):
         return x_grad - neg_delta_grad, None, None, None, None, None, None
 
 
+
+
 class BasicNorm(torch.nn.Module):
     """
     This is intended to be a simpler, and hopefully cheaper, replacement for
@@ -404,12 +415,12 @@ class ActivationBalancer(torch.nn.Module):
            either the sign constraint or the magnitude constraint;
            e.g. with max_factor=0.02, the derivatives would be multiplied by
            values in the range [0.98..1.02].
-           min_abs: the minimum average-absolute-value per channel, which
-             we allow, before we start to modify the derivatives to prevent
-             this.
-           max_abs: the maximum average-absolute-value per channel, which
-             we allow, before we start to modify the derivatives to prevent
-             this.
+           min_abs: the minimum average-absolute-value difference from the mean
+             value per channel, which we allow, before we start to modify
+             the derivatives to prevent this.
+           max_abs: the maximum average-absolute-value difference from the mean
+             value per channel, which we allow, before we start to modify
+             the derivatives to prevent this.
     """
 
     def __init__(
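For context, a hedged usage sketch of `ActivationBalancer` under the re-worded docstring. The keyword names come from the docstring and code above; the concrete values, the argument spelling `channel_dim`, and the assumption that the forward pass is an identity (all constraints act only on the gradient) are mine:

import torch

# Assumes ActivationBalancer is in scope, as defined in scaling.py.
balancer = ActivationBalancer(
    channel_dim=-1,
    min_positive=0.05,
    max_positive=0.95,
    max_factor=0.02,  # derivatives scaled within roughly [0.98, 1.02]
    min_abs=0.2,      # now: minimum mean |x - mean(x)| per channel
    max_abs=100.0,    # now: maximum mean |x - mean(x)| per channel
)
x = torch.randn(4, 100, 256, requires_grad=True)
y = balancer(x)     # forward: (assumed) identity
y.sum().backward()  # backward: gradients receive the per-channel nudges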