Add min-abs-value 0.2

Daniel Povey 2022-03-10 23:48:46 +08:00
parent 2fa9c636a4
commit 76560f255c
2 changed files with 47 additions and 27 deletions


@@ -312,33 +312,36 @@ class ExpScaleRelu(torch.nn.Module):
 class DerivBalancerFunction(torch.autograd.Function):
     @staticmethod
-    def forward(ctx, x: Tensor, channel_dim: int,
+    def forward(ctx, x: Tensor,
+                channel_dim: int,
                 threshold: float = 0.05,
                 max_factor: float = 0.05,
-                zero: float = 0.02,
-                epsilon: float = 1.0e-10) -> Tensor:
+                min_abs: float = 0.2) -> Tensor:
         if x.requires_grad:
             if channel_dim < 0:
                 channel_dim += x.ndim
             sum_dims = [d for d in range(x.ndim) if d != channel_dim]
-            proportion_positive = torch.mean((x > zero).to(x.dtype), dim=sum_dims, keepdim=True)
+            xgt0 = x > 0
+            proportion_positive = torch.mean(xgt0.to(x.dtype), dim=sum_dims, keepdim=True)
             factor = (threshold - proportion_positive).relu() * (max_factor / threshold)
-            ctx.save_for_backward(factor)
-            ctx.epsilon = epsilon
+            below_threshold = (torch.mean(x.abs(), dim=sum_dims, keepdim=True) < min_abs)
+            ctx.save_for_backward(factor, xgt0, below_threshold)
+            ctx.max_factor = max_factor
             ctx.sum_dims = sum_dims
         return x
 
     @staticmethod
-    def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None, None, None]:
-        factor, = ctx.saved_tensors
-        neg_delta_grad = x_grad.abs() * factor
-        if ctx.epsilon != 0.0:
-            sum_abs_grad = torch.sum(x_grad.abs(), dim=ctx.sum_dims, keepdim=True)
-            deriv_is_zero = (sum_abs_grad == 0.0)
-            neg_delta_grad += ctx.epsilon * deriv_is_zero
-        return x_grad - neg_delta_grad, None, None, None, None, None
+    def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None, None]:
+        factor, xgt0, below_threshold = ctx.saved_tensors
+        dtype = x_grad.dtype
+        too_small_factor = below_threshold.to(dtype) * (xgt0.to(dtype) - 0.5) * (ctx.max_factor * 2.0)
+        neg_delta_grad = x_grad.abs() * (factor + too_small_factor)
+        return x_grad - neg_delta_grad, None, None, None, None
 
 
 class BasicNorm(torch.nn.Module):
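In words, the new min_abs argument means: channels whose mean absolute value falls below min_abs get an extra term of magnitude max_factor * |grad| added to their gradient, signed so that SGD nudges those activations toward larger magnitude. A minimal sketch of that arithmetic (not part of the commit; the tensor values are made up, and DerivBalancerFunction from the diff above is assumed to be in scope):

import torch

# One channel (channel_dim=0) whose mean |x| = 0.05 is below min_abs = 0.2.
x = torch.full((1, 4), 0.05, requires_grad=True)
# Positional args: channel_dim, threshold, max_factor, min_abs.
y = DerivBalancerFunction.apply(x, 0, 0.05, 0.02, 0.2)
y.backward(gradient=torch.ones_like(x))
# proportion_positive is 1.0 > threshold, so factor == 0; below_threshold is True,
# so too_small_factor == +max_factor == 0.02 for these positive entries, and the
# incoming gradient of 1.0 comes back as 1.0 - 0.02 = 0.98.
print(x.grad)   # tensor([[0.9800, 0.9800, 0.9800, 0.9800]])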
@@ -449,19 +452,17 @@ class DerivBalancer(torch.nn.Module):
     def __init__(self, channel_dim: int,
                  threshold: float = 0.05,
                  max_factor: float = 0.02,
-                 zero: float = 0.02,
-                 epsilon: float = 1.0e-10):
+                 min_abs: float = 0.2):
         super(DerivBalancer, self).__init__()
         self.channel_dim = channel_dim
         self.threshold = threshold
         self.max_factor = max_factor
-        self.zero = zero
-        self.epsilon = epsilon
+        self.min_abs = min_abs
 
     def forward(self, x: Tensor) -> Tensor:
         return DerivBalancerFunction.apply(x, self.channel_dim, self.threshold,
-                                           self.max_factor, self.zero,
-                                           self.epsilon)
+                                           self.max_factor, self.min_abs)
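DerivBalancer is the thin module wrapper over the autograd function above (identity in the forward pass, gradient adjustment in the backward pass). A purely illustrative placement, where the layer sizes and the position in the network are my own assumptions rather than something taken from this commit:

import torch.nn as nn

layer = nn.Sequential(
    nn.Linear(256, 256),                          # hypothetical feedforward layer
    DerivBalancer(channel_dim=-1, threshold=0.05,
                  max_factor=0.02, min_abs=0.2),  # no-op in forward, rescales grads in backward
    nn.ReLU(),
)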
@@ -505,23 +506,41 @@ def _test_exp_scale_relu():
-def _test_deriv_balancer():
+def _test_deriv_balancer_sign():
     channel_dim = 0
     probs = torch.arange(0, 1, 0.01)
     N = 500
     x = 1.0 * (torch.rand(probs.numel(), N) < probs.unsqueeze(-1))
     x = x.detach()
     x.requires_grad = True
-    m = DerivBalancer(channel_dim=0, threshold=0.05, max_factor=0.2, epsilon=1.0e-10)
+    m = DerivBalancer(channel_dim=0, threshold=0.05, max_factor=0.2, min_abs=0.2)
 
     y_grad = torch.sign(torch.randn(probs.numel(), N))
     y_grad[-1,:] = 0
 
     y = m(x)
     y.backward(gradient=y_grad)
-    print("x = ", x)
-    print("y grad = ", y_grad)
-    print("x grad = ", x.grad)
+    print("_test_deriv_balancer_sign: x = ", x)
+    print("_test_deriv_balancer_sign: y grad = ", y_grad)
+    print("_test_deriv_balancer_sign: x grad = ", x.grad)
+
+
+def _test_deriv_balancer_magnitude():
+    channel_dim = 0
+    magnitudes = torch.arange(0, 1, 0.01)
+    N = 500
+    x = 1.0 * (torch.randn(magnitudes.numel(), N) * magnitudes.unsqueeze(-1))
+    x = x.detach()
+    x.requires_grad = True
+    m = DerivBalancer(channel_dim=0, threshold=0.05, max_factor=0.2, min_abs=0.2)
+
+    y_grad = torch.sign(torch.randn(magnitudes.numel(), N))
+    y_grad[-1,:] = 0
+
+    y = m(x)
+    y.backward(gradient=y_grad)
+    print("_test_deriv_balancer_magnitude: x = ", x)
+    print("_test_deriv_balancer_magnitude: y grad = ", y_grad)
+    print("_test_deriv_balancer_magnitude: x grad = ", x.grad)
 
 
 def _test_basic_norm():
@@ -543,7 +562,8 @@ def _test_basic_norm():
 
 
 if __name__ == '__main__':
-    _test_deriv_balancer()
+    _test_deriv_balancer_sign()
+    _test_deriv_balancer_magnitude()
     _test_exp_scale_swish()
     _test_exp_scale_relu()
     _test_basic_norm()
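Both tests only print tensors for visual inspection. If a numerical cross-check is wanted, a rough reference re-computation of the amount the balancer subtracts from the incoming gradient could look like the sketch below; the helper name and the suggested assertion are my own, not part of the commit.

import torch

def _reference_delta(x, y_grad, threshold=0.05, max_factor=0.2, min_abs=0.2):
    # Mirrors DerivBalancerFunction for channel_dim=0: per-row statistics over dim 1.
    xgt0 = (x > 0).to(x.dtype)
    proportion_positive = xgt0.mean(dim=1, keepdim=True)
    factor = (threshold - proportion_positive).relu() * (max_factor / threshold)
    below_threshold = (x.abs().mean(dim=1, keepdim=True) < min_abs).to(x.dtype)
    too_small_factor = below_threshold * (xgt0 - 0.5) * (max_factor * 2.0)
    return y_grad.abs() * (factor + too_small_factor)

# e.g. inside _test_deriv_balancer_magnitude, after y.backward(gradient=y_grad):
#   assert torch.allclose(y_grad - x.grad, _reference_delta(x.detach(), y_grad))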


@@ -110,7 +110,7 @@ def get_parser():
     parser.add_argument(
         "--exp-dir",
         type=str,
-        default="transducer_stateless/specaugmod_baseline_randcombine1_expscale3_brelu2swish2_0.1_bnorm2z0.02",
+        default="transducer_stateless/specaugmod_baseline_randcombine1_expscale3_brelu2swish2_0.1_bnorm2ma0.1",
         help="""The experiment dir.
         It specifies the directory where all training related
         files, e.g., checkpoints, log, etc, are saved