Add max-abs-value

This commit is contained in:
Daniel Povey 2022-03-13 13:15:20 +08:00
parent e6a501d3c8
commit 5d69acb25b
2 changed files with 33 additions and 25 deletions

View File

@ -47,14 +47,12 @@ class Conv2dSubsampling(nn.Module):
ScaledConv2d( ScaledConv2d(
in_channels=1, out_channels=odim, kernel_size=3, stride=2 in_channels=1, out_channels=odim, kernel_size=3, stride=2
), ),
DerivBalancer(channel_dim=1, threshold=0.05, DerivBalancer(channel_dim=1),
max_factor=0.01),
DoubleSwish(), DoubleSwish(),
ScaledConv2d( ScaledConv2d(
in_channels=odim, out_channels=odim, kernel_size=3, stride=2 in_channels=odim, out_channels=odim, kernel_size=3, stride=2
), ),
DerivBalancer(channel_dim=1, threshold=0.05, DerivBalancer(channel_dim=1),
max_factor=0.01),
DoubleSwish(), DoubleSwish(),
) )
self.out = ScaledLinear(odim * (((idim - 1) // 2 - 1) // 2), odim) self.out = ScaledLinear(odim * (((idim - 1) // 2 - 1) // 2), odim)
@ -325,7 +323,8 @@ class DerivBalancerFunction(torch.autograd.Function):
@staticmethod @staticmethod
def forward(ctx, x: Tensor, def forward(ctx, x: Tensor,
channel_dim: int, channel_dim: int,
threshold: float, # e.g. 0.05 min_positive: float, # e.g. 0.05
max_positive: float, # e.g. 0.95
max_factor: float, # e.g. 0.01 max_factor: float, # e.g. 0.01
min_abs: float, # e.g. 0.2 min_abs: float, # e.g. 0.2
max_abs: float, # e.g. 1000.0 max_abs: float, # e.g. 1000.0
@ -336,7 +335,13 @@ class DerivBalancerFunction(torch.autograd.Function):
sum_dims = [d for d in range(x.ndim) if d != channel_dim] sum_dims = [d for d in range(x.ndim) if d != channel_dim]
xgt0 = x > 0 xgt0 = x > 0
proportion_positive = torch.mean(xgt0.to(x.dtype), dim=sum_dims, keepdim=True) proportion_positive = torch.mean(xgt0.to(x.dtype), dim=sum_dims, keepdim=True)
factor = (threshold - proportion_positive).relu() * (max_factor / threshold) factor1 = ((min_positive - proportion_positive).relu() * (max_factor / min_positive)
if min_positive != 0.0 else 0.0)
factor2 = ((proportion_positive - max_positive).relu() * (max_factor / (max_positive - 1.0))
if max_positive != 1.0 else 0.0)
factor = factor1 + factor2
if isinstance(factor, float):
factor = torch.zeros_like(proportion_positive)
mean_abs = torch.mean(x.abs(), dim=sum_dims, keepdim=True) mean_abs = torch.mean(x.abs(), dim=sum_dims, keepdim=True)
below_threshold = (mean_abs < min_abs) below_threshold = (mean_abs < min_abs)
@ -348,16 +353,14 @@ class DerivBalancerFunction(torch.autograd.Function):
return x return x
@staticmethod @staticmethod
def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None, None, None]: def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None, None, None, None]:
factor, xgt0, below_threshold, above_threshold = ctx.saved_tensors factor, xgt0, below_threshold, above_threshold = ctx.saved_tensors
dtype = x_grad.dtype dtype = x_grad.dtype
scale_factor = ((below_threshold.to(dtype) - above_threshold.to(dtype)) * scale_factor = ((below_threshold.to(dtype) - above_threshold.to(dtype)) *
(xgt0.to(dtype) - 0.5) * (ctx.max_factor * 2.0)) (xgt0.to(dtype) - 0.5) * (ctx.max_factor * 2.0))
neg_delta_grad = x_grad.abs() * (factor + scale_factor) neg_delta_grad = x_grad.abs() * (factor + scale_factor)
return x_grad - neg_delta_grad, None, None, None, None, None, None
return x_grad - neg_delta_grad, None, None, None, None, None
class BasicNorm(torch.nn.Module): class BasicNorm(torch.nn.Module):
@ -516,7 +519,9 @@ class DerivBalancer(torch.nn.Module):
Args: Args:
channel_dim: the dimension/axi corresponding to the channel, e.g. channel_dim: the dimension/axi corresponding to the channel, e.g.
-1, 0, 1, 2; will be interpreted as an offset from x.ndim if negative. -1, 0, 1, 2; will be interpreted as an offset from x.ndim if negative.
threshold: the threshold, per channel, of the proportion of the time min_positive: the minimum, per channel, of the proportion of the time
that (x > 0), below which we start to modify the derivatives.
max_positive: the maximum, per channel, of the proportion of the time
that (x > 0), below which we start to modify the derivatives. that (x > 0), below which we start to modify the derivatives.
max_factor: the maximum factor by which we modify the derivatives, max_factor: the maximum factor by which we modify the derivatives,
e.g. with max_factor=0.02, the the derivatives would be multiplied by e.g. with max_factor=0.02, the the derivatives would be multiplied by
@ -538,19 +543,22 @@ class DerivBalancer(torch.nn.Module):
out of floating point numerical range (especially in half precision). out of floating point numerical range (especially in half precision).
""" """
def __init__(self, channel_dim: int, def __init__(self, channel_dim: int,
threshold: float = 0.05, min_positive: float = 0.05,
max_positive: float = 0.95,
max_factor: float = 0.01, max_factor: float = 0.01,
min_abs: float = 0.2, min_abs: float = 0.2,
max_abs: float = 1000.0): max_abs: float = 1000.0):
super(DerivBalancer, self).__init__() super(DerivBalancer, self).__init__()
self.channel_dim = channel_dim self.channel_dim = channel_dim
self.threshold = threshold self.min_positive = min_positive
self.max_positive = max_positive
self.max_factor = max_factor self.max_factor = max_factor
self.min_abs = min_abs self.min_abs = min_abs
self.max_abs = max_abs self.max_abs = max_abs
def forward(self, x: Tensor) -> Tensor: def forward(self, x: Tensor) -> Tensor:
return DerivBalancerFunction.apply(x, self.channel_dim, self.threshold, return DerivBalancerFunction.apply(x, self.channel_dim,
self.min_positive, self.max_positive,
self.max_factor, self.min_abs, self.max_factor, self.min_abs,
self.max_abs) self.max_abs)
@ -600,14 +608,14 @@ def _test_exp_scale_relu():
def _test_deriv_balancer_sign(): def _test_deriv_balancer_sign():
channel_dim = 0 channel_dim = 0
probs = torch.arange(0, 1, 0.01) probs = torch.arange(0, 1, 0.01)
N = 500 N = 1000
x = 1.0 * (torch.rand(probs.numel(), N) < probs.unsqueeze(-1)) x = 1.0 * (torch.rand(probs.numel(), N) < probs.unsqueeze(-1))
x = x.detach() x = x.detach()
x.requires_grad = True x.requires_grad = True
m = DerivBalancer(channel_dim=0, threshold=0.05, max_factor=0.2, min_abs=0.2) m = DerivBalancer(channel_dim=0, min_positive=0.05, max_positive=0.95,
max_factor=0.2, min_abs=0.0)
y_grad = torch.sign(torch.randn(probs.numel(), N)) y_grad = torch.sign(torch.randn(probs.numel(), N))
y_grad[-1,:] = 0
y = m(x) y = m(x)
y.backward(gradient=y_grad) y.backward(gradient=y_grad)
@ -618,14 +626,16 @@ def _test_deriv_balancer_sign():
def _test_deriv_balancer_magnitude(): def _test_deriv_balancer_magnitude():
channel_dim = 0 channel_dim = 0
magnitudes = torch.arange(0, 1, 0.01) magnitudes = torch.arange(0, 1, 0.01)
N = 500 N = 1000
x = 1.0 * (torch.randn(magnitudes.numel(), N) * magnitudes.unsqueeze(-1)) x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze(-1)
x = x.detach() x = x.detach()
x.requires_grad = True x.requires_grad = True
m = DerivBalancer(channel_dim=0, threshold=0.05, max_factor=0.2, min_abs=0.2) m = DerivBalancer(channel_dim=0,
min_positive=0.0, max_positive=1.0,
max_factor=0.2,
min_abs=0.2, max_abs=0.8)
y_grad = torch.sign(torch.randn(magnitudes.numel(), N)) y_grad = torch.sign(torch.randn(magnitudes.numel(), N))
y_grad[-1,:] = 0
y = m(x) y = m(x)
y.backward(gradient=y_grad) y.backward(gradient=y_grad)

View File

@ -158,8 +158,7 @@ class ConformerEncoderLayer(nn.Module):
self.feed_forward = nn.Sequential( self.feed_forward = nn.Sequential(
ScaledLinear(d_model, dim_feedforward), ScaledLinear(d_model, dim_feedforward),
DerivBalancer(channel_dim=-1, threshold=0.05, DerivBalancer(channel_dim=-1),
max_factor=0.01),
SwishExpScale(dim_feedforward, speed=20.0), SwishExpScale(dim_feedforward, speed=20.0),
nn.Dropout(dropout), nn.Dropout(dropout),
ScaledLinear(dim_feedforward, d_model, initial_scale=0.25), ScaledLinear(dim_feedforward, d_model, initial_scale=0.25),
@ -167,8 +166,7 @@ class ConformerEncoderLayer(nn.Module):
self.feed_forward_macaron = nn.Sequential( self.feed_forward_macaron = nn.Sequential(
ScaledLinear(d_model, dim_feedforward), ScaledLinear(d_model, dim_feedforward),
DerivBalancer(channel_dim=-1, threshold=0.05, DerivBalancer(channel_dim=-1),
max_factor=0.01),
SwishExpScale(dim_feedforward, speed=20.0), SwishExpScale(dim_feedforward, speed=20.0),
nn.Dropout(dropout), nn.Dropout(dropout),
ScaledLinear(dim_feedforward, d_model, initial_scale=0.25), ScaledLinear(dim_feedforward, d_model, initial_scale=0.25),