Repository: https://github.com/k2-fsa/icefall.git
Commit 5d69acb25b (parent e6a501d3c8)

    Add max-abs-value
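In brief: DerivBalancer's single threshold argument (a lower bound on the per-channel proportion of positive activations) becomes a min_positive/max_positive pair, so gradients can now be nudged both when a channel is positive too rarely and when it is positive too often; the new argument is threaded through DerivBalancerFunction.forward/backward, and call sites shrink to rely on the new defaults. A minimal usage sketch of the module as it looks after this commit (the tensor shape and the sum() loss are illustrative only; DerivBalancer is the class changed below):

    import torch

    balancer = DerivBalancer(channel_dim=-1)  # defaults after this commit:
                                              # min_positive=0.05, max_positive=0.95,
                                              # max_factor=0.01, min_abs=0.2,
                                              # max_abs=1000.0
    x = torch.randn(4, 16, requires_grad=True)
    y = balancer(x)          # the forward pass is the identity
    y.sum().backward()       # the balancing happens in the backward pass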
@@ -47,14 +47,12 @@ class Conv2dSubsampling(nn.Module):
             ScaledConv2d(
                 in_channels=1, out_channels=odim, kernel_size=3, stride=2
             ),
-            DerivBalancer(channel_dim=1, threshold=0.05,
-                          max_factor=0.01),
+            DerivBalancer(channel_dim=1),
             DoubleSwish(),
             ScaledConv2d(
                 in_channels=odim, out_channels=odim, kernel_size=3, stride=2
             ),
-            DerivBalancer(channel_dim=1, threshold=0.05,
-                          max_factor=0.01),
+            DerivBalancer(channel_dim=1),
             DoubleSwish(),
         )
         self.out = ScaledLinear(odim * (((idim - 1) // 2 - 1) // 2), odim)
@@ -325,7 +323,8 @@ class DerivBalancerFunction(torch.autograd.Function):
     @staticmethod
     def forward(ctx, x: Tensor,
                 channel_dim: int,
-                threshold: float,  # e.g. 0.05
+                min_positive: float,  # e.g. 0.05
+                max_positive: float,  # e.g. 0.95
                 max_factor: float,  # e.g. 0.01
                 min_abs: float,  # e.g. 0.2
                 max_abs: float,  # e.g. 1000.0
@@ -336,7 +335,13 @@ class DerivBalancerFunction(torch.autograd.Function):
         sum_dims = [d for d in range(x.ndim) if d != channel_dim]
         xgt0 = x > 0
         proportion_positive = torch.mean(xgt0.to(x.dtype), dim=sum_dims, keepdim=True)
-        factor = (threshold - proportion_positive).relu() * (max_factor / threshold)
+        factor1 = ((min_positive - proportion_positive).relu() * (max_factor / min_positive)
+                   if min_positive != 0.0 else 0.0)
+        factor2 = ((proportion_positive - max_positive).relu() * (max_factor / (max_positive - 1.0))
+                   if max_positive != 1.0 else 0.0)
+        factor = factor1 + factor2
+        if isinstance(factor, float):
+            factor = torch.zeros_like(proportion_positive)

         mean_abs = torch.mean(x.abs(), dim=sum_dims, keepdim=True)
         below_threshold = (mean_abs < min_abs)
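The new factor computation generalizes the old one-sided rule. A small standalone sketch of just this arithmetic (plain PyTorch; the proportions are made up to show all three regimes, and the "!= 0.0" guards from the diff, which only avoid division by zero, are skipped):

    import torch

    min_positive, max_positive, max_factor = 0.05, 0.95, 0.01
    # per-channel proportion of the time that x > 0, for three example channels
    proportion_positive = torch.tensor([[0.02], [0.50], [0.99]])

    # penalize channels that are almost never positive ...
    factor1 = (min_positive - proportion_positive).relu() * (max_factor / min_positive)
    # ... and channels that are almost always positive; max_positive - 1.0 < 0,
    # so this term is <= 0 and pulls in the opposite direction
    factor2 = (proportion_positive - max_positive).relu() * (max_factor / (max_positive - 1.0))
    factor = factor1 + factor2
    # factor is approximately [[0.006], [0.0], [-0.008]]: zero inside the
    # [min_positive, max_positive] band, growing linearly to +/-max_factor
    # at proportions 0 and 1.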
@@ -348,16 +353,14 @@ class DerivBalancerFunction(torch.autograd.Function):
         return x

     @staticmethod
-    def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None, None, None]:
+    def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None, None, None, None]:
         factor, xgt0, below_threshold, above_threshold = ctx.saved_tensors
         dtype = x_grad.dtype
         scale_factor = ((below_threshold.to(dtype) - above_threshold.to(dtype)) *
                         (xgt0.to(dtype) - 0.5) * (ctx.max_factor * 2.0))

         neg_delta_grad = x_grad.abs() * (factor + scale_factor)
-
-        return x_grad - neg_delta_grad, None, None, None, None, None
-
+        return x_grad - neg_delta_grad, None, None, None, None, None, None


 class BasicNorm(torch.nn.Module):
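The extra trailing None in backward is required, not cosmetic: torch.autograd.Function.backward must return one value per argument of forward, and forward gained one argument (max_positive). A toy Function illustrating that contract:

    import torch
    from torch import Tensor

    class Scale(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x: Tensor, alpha: float) -> Tensor:
            ctx.alpha = alpha
            return x * alpha

        @staticmethod
        def backward(ctx, g: Tensor):
            # one return per forward argument: a gradient for x, None for alpha
            return g * ctx.alpha, None

    x = torch.ones(3, requires_grad=True)
    Scale.apply(x, 2.0).sum().backward()
    print(x.grad)  # tensor([2., 2., 2.])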
@@ -516,7 +519,9 @@ class DerivBalancer(torch.nn.Module):
     Args:
       channel_dim: the dimension/axis corresponding to the channel, e.g.
         -1, 0, 1, 2; will be interpreted as an offset from x.ndim if negative.
-      threshold: the threshold, per channel, of the proportion of the time
+      min_positive: the minimum, per channel, of the proportion of the time
         that (x > 0), below which we start to modify the derivatives.
+      max_positive: the maximum, per channel, of the proportion of the time
+        that (x > 0), above which we start to modify the derivatives.
       max_factor: the maximum factor by which we modify the derivatives,
         e.g. with max_factor=0.02, the derivatives would be multiplied by
@@ -538,19 +543,22 @@ class DerivBalancer(torch.nn.Module):
     out of floating point numerical range (especially in half precision).
     """
     def __init__(self, channel_dim: int,
-                 threshold: float = 0.05,
+                 min_positive: float = 0.05,
+                 max_positive: float = 0.95,
                  max_factor: float = 0.01,
                  min_abs: float = 0.2,
                  max_abs: float = 1000.0):
         super(DerivBalancer, self).__init__()
         self.channel_dim = channel_dim
-        self.threshold = threshold
+        self.min_positive = min_positive
+        self.max_positive = max_positive
         self.max_factor = max_factor
         self.min_abs = min_abs
         self.max_abs = max_abs

     def forward(self, x: Tensor) -> Tensor:
-        return DerivBalancerFunction.apply(x, self.channel_dim, self.threshold,
+        return DerivBalancerFunction.apply(x, self.channel_dim,
+                                           self.min_positive, self.max_positive,
                                            self.max_factor, self.min_abs,
                                            self.max_abs)
@@ -600,14 +608,14 @@ def _test_exp_scale_relu():
 def _test_deriv_balancer_sign():
     channel_dim = 0
     probs = torch.arange(0, 1, 0.01)
-    N = 500
+    N = 1000
     x = 1.0 * (torch.rand(probs.numel(), N) < probs.unsqueeze(-1))
     x = x.detach()
     x.requires_grad = True
-    m = DerivBalancer(channel_dim=0, threshold=0.05, max_factor=0.2, min_abs=0.2)
+    m = DerivBalancer(channel_dim=0, min_positive=0.05, max_positive=0.95,
+                      max_factor=0.2, min_abs=0.0)

     y_grad = torch.sign(torch.randn(probs.numel(), N))
-    y_grad[-1,:] = 0

     y = m(x)
     y.backward(gradient=y_grad)
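Presumably min_abs is set to 0.0 here (rather than the 0.2 default) so that mean_abs < min_abs is never true and the magnitude term stays inert, letting this test isolate the sign-balancing behavior: only rows whose proportion of positive entries falls outside [0.05, 0.95] should see modified gradients.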
@@ -618,14 +626,16 @@ def _test_deriv_balancer_sign():
 def _test_deriv_balancer_magnitude():
     channel_dim = 0
     magnitudes = torch.arange(0, 1, 0.01)
-    N = 500
-    x = 1.0 * (torch.randn(magnitudes.numel(), N) * magnitudes.unsqueeze(-1))
+    N = 1000
+    x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze(-1)
     x = x.detach()
     x.requires_grad = True
-    m = DerivBalancer(channel_dim=0, threshold=0.05, max_factor=0.2, min_abs=0.2)
+    m = DerivBalancer(channel_dim=0,
+                      min_positive=0.0, max_positive=1.0,
+                      max_factor=0.2,
+                      min_abs=0.2, max_abs=0.8)

     y_grad = torch.sign(torch.randn(magnitudes.numel(), N))
-    y_grad[-1,:] = 0

     y = m(x)
     y.backward(gradient=y_grad)
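With min_positive=0.0 and max_positive=1.0, both ternaries in forward take their float 0.0 branch, so this test also exercises the new isinstance(factor, float) fallback. And because each row of x now has constant magnitude (a random sign times the row's magnitude), min_abs=0.2 and max_abs=0.8 cleanly partition the rows. A quick way to inspect the effect on the gradients (a sketch, assuming it runs in the same file where DerivBalancer is defined):

    import torch

    magnitudes = torch.arange(0, 1, 0.01)
    x = torch.sign(torch.randn(magnitudes.numel(), 1000)) * magnitudes.unsqueeze(-1)
    x = x.detach()
    x.requires_grad = True
    m = DerivBalancer(channel_dim=0, min_positive=0.0, max_positive=1.0,
                      max_factor=0.2, min_abs=0.2, max_abs=0.8)
    m(x).backward(gradient=torch.ones_like(x))
    # In rows with |x| < 0.2, grads at x > 0 shrink to 1 - 0.2 and grads at
    # x <= 0 grow to 1 + 0.2, so a gradient-descent step pushes those channels
    # toward larger magnitude; rows with |x| > 0.8 are nudged the opposite way.
    print(x.grad[5, :5], x.grad[90, :5])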
@@ -158,8 +158,7 @@ class ConformerEncoderLayer(nn.Module):

         self.feed_forward = nn.Sequential(
             ScaledLinear(d_model, dim_feedforward),
-            DerivBalancer(channel_dim=-1, threshold=0.05,
-                          max_factor=0.01),
+            DerivBalancer(channel_dim=-1),
             SwishExpScale(dim_feedforward, speed=20.0),
             nn.Dropout(dropout),
             ScaledLinear(dim_feedforward, d_model, initial_scale=0.25),
@@ -167,8 +166,7 @@ class ConformerEncoderLayer(nn.Module):

         self.feed_forward_macaron = nn.Sequential(
             ScaledLinear(d_model, dim_feedforward),
-            DerivBalancer(channel_dim=-1, threshold=0.05,
-                          max_factor=0.01),
+            DerivBalancer(channel_dim=-1),
             SwishExpScale(dim_feedforward, speed=20.0),
             nn.Dropout(dropout),
             ScaledLinear(dim_feedforward, d_model, initial_scale=0.25),
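These two call sites, like the Conv2dSubsampling one above, now rely on the new defaults; spelled out, the shortened constructor call is equivalent to:

    DerivBalancer(channel_dim=-1, min_positive=0.05, max_positive=0.95,
                  max_factor=0.01, min_abs=0.2, max_abs=1000.0)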