Add min-abs-value 0.2

This commit is contained in:
Daniel Povey 2022-03-10 23:48:46 +08:00
parent 2fa9c636a4
commit 76560f255c
2 changed files with 47 additions and 27 deletions

View File

@ -312,33 +312,36 @@ class ExpScaleRelu(torch.nn.Module):
class DerivBalancerFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, x: Tensor, channel_dim: int,
def forward(ctx, x: Tensor,
channel_dim: int,
threshold: float = 0.05,
max_factor: float = 0.05,
zero: float = 0.02,
epsilon: float = 1.0e-10) -> Tensor:
min_abs: float = 0.2) -> Tensor:
if x.requires_grad:
if channel_dim < 0:
channel_dim += x.ndim
sum_dims = [d for d in range(x.ndim) if d != channel_dim]
proportion_positive = torch.mean((x > zero).to(x.dtype), dim=sum_dims, keepdim=True)
xgt0 = x > 0
proportion_positive = torch.mean(xgt0.to(x.dtype), dim=sum_dims, keepdim=True)
factor = (threshold - proportion_positive).relu() * (max_factor / threshold)
ctx.save_for_backward(factor)
ctx.epsilon = epsilon
below_threshold = (torch.mean(x.abs(), dim=sum_dims, keepdim=True) < min_abs)
ctx.save_for_backward(factor, xgt0, below_threshold)
ctx.max_factor = max_factor
ctx.sum_dims = sum_dims
return x
@staticmethod
def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None, None, None]:
factor, = ctx.saved_tensors
neg_delta_grad = x_grad.abs() * factor
if ctx.epsilon != 0.0:
sum_abs_grad = torch.sum(x_grad.abs(), dim=ctx.sum_dims, keepdim=True)
deriv_is_zero = (sum_abs_grad == 0.0)
neg_delta_grad += ctx.epsilon * deriv_is_zero
def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None, None]:
factor, xgt0, below_threshold = ctx.saved_tensors
dtype = x_grad.dtype
too_small_factor = below_threshold.to(dtype) * (xgt0.to(dtype) - 0.5) * (ctx.max_factor * 2.0)
return x_grad - neg_delta_grad, None, None, None, None, None
neg_delta_grad = x_grad.abs() * (factor + too_small_factor)
return x_grad - neg_delta_grad, None, None, None, None
class BasicNorm(torch.nn.Module):
@ -449,19 +452,17 @@ class DerivBalancer(torch.nn.Module):
def __init__(self, channel_dim: int,
threshold: float = 0.05,
max_factor: float = 0.02,
zero: float = 0.02,
epsilon: float = 1.0e-10):
min_abs: float = 0.2):
super(DerivBalancer, self).__init__()
self.channel_dim = channel_dim
self.threshold = threshold
self.max_factor = max_factor
self.zero = zero
self.epsilon = epsilon
self.min_abs = min_abs
def forward(self, x: Tensor) -> Tensor:
return DerivBalancerFunction.apply(x, self.channel_dim, self.threshold,
self.max_factor, self.zero,
self.epsilon)
self.max_factor, self.min_abs)
@ -505,23 +506,41 @@ def _test_exp_scale_relu():
def _test_deriv_balancer():
def _test_deriv_balancer_sign():
channel_dim = 0
probs = torch.arange(0, 1, 0.01)
N = 500
x = 1.0 * (torch.rand(probs.numel(), N) < probs.unsqueeze(-1))
x = x.detach()
x.requires_grad = True
m = DerivBalancer(channel_dim=0, threshold=0.05, max_factor=0.2, epsilon=1.0e-10)
m = DerivBalancer(channel_dim=0, threshold=0.05, max_factor=0.2, min_abs=0.2)
y_grad = torch.sign(torch.randn(probs.numel(), N))
y_grad[-1,:] = 0
y = m(x)
y.backward(gradient=y_grad)
print("x = ", x)
print("y grad = ", y_grad)
print("x grad = ", x.grad)
print("_test_deriv_balancer_sign: x = ", x)
print("_test_deriv_balancer_sign: y grad = ", y_grad)
print("_test_deriv_balancer_sign: x grad = ", x.grad)
def _test_deriv_balancer_magnitude():
channel_dim = 0
magnitudes = torch.arange(0, 1, 0.01)
N = 500
x = 1.0 * (torch.randn(magnitudes.numel(), N) * magnitudes.unsqueeze(-1))
x = x.detach()
x.requires_grad = True
m = DerivBalancer(channel_dim=0, threshold=0.05, max_factor=0.2, min_abs=0.2)
y_grad = torch.sign(torch.randn(magnitudes.numel(), N))
y_grad[-1,:] = 0
y = m(x)
y.backward(gradient=y_grad)
print("_test_deriv_balancer_magnitude: x = ", x)
print("_test_deriv_balancer_magnitude: y grad = ", y_grad)
print("_test_deriv_balancer_magnitude: x grad = ", x.grad)
def _test_basic_norm():
@ -543,7 +562,8 @@ def _test_basic_norm():
if __name__ == '__main__':
_test_deriv_balancer()
_test_deriv_balancer_sign()
_test_deriv_balancer_magnitude()
_test_exp_scale_swish()
_test_exp_scale_relu()
_test_basic_norm()

View File

@ -110,7 +110,7 @@ def get_parser():
parser.add_argument(
"--exp-dir",
type=str,
default="transducer_stateless/specaugmod_baseline_randcombine1_expscale3_brelu2swish2_0.1_bnorm2z0.02",
default="transducer_stateless/specaugmod_baseline_randcombine1_expscale3_brelu2swish2_0.1_bnorm2ma0.1",
help="""The experiment dir.
It specifies the directory where all training related
files, e.g., checkpoints, log, etc, are saved