Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-08-08 17:42:21 +00:00)
Commit 76560f255c: Add min-abs-value 0.2
Parent: 2fa9c636a4
@@ -312,33 +312,36 @@ class ExpScaleRelu(torch.nn.Module):
 
 class DerivBalancerFunction(torch.autograd.Function):
     @staticmethod
-    def forward(ctx, x: Tensor, channel_dim: int,
+    def forward(ctx, x: Tensor,
+                channel_dim: int,
                 threshold: float = 0.05,
                 max_factor: float = 0.05,
-                zero: float = 0.02,
-                epsilon: float = 1.0e-10) -> Tensor:
+                min_abs: float = 0.2) -> Tensor:
         if x.requires_grad:
             if channel_dim < 0:
                 channel_dim += x.ndim
             sum_dims = [d for d in range(x.ndim) if d != channel_dim]
-            proportion_positive = torch.mean((x > zero).to(x.dtype), dim=sum_dims, keepdim=True)
+            xgt0 = x > 0
+            proportion_positive = torch.mean(xgt0.to(x.dtype), dim=sum_dims, keepdim=True)
             factor = (threshold - proportion_positive).relu() * (max_factor / threshold)
+            below_threshold = (torch.mean(x.abs(), dim=sum_dims, keepdim=True) < min_abs)
 
-            ctx.save_for_backward(factor)
-            ctx.epsilon = epsilon
+            ctx.save_for_backward(factor, xgt0, below_threshold)
+            ctx.max_factor = max_factor
             ctx.sum_dims = sum_dims
         return x
 
     @staticmethod
-    def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None, None, None]:
-        factor, = ctx.saved_tensors
-        neg_delta_grad = x_grad.abs() * factor
-        if ctx.epsilon != 0.0:
-            sum_abs_grad = torch.sum(x_grad.abs(), dim=ctx.sum_dims, keepdim=True)
-            deriv_is_zero = (sum_abs_grad == 0.0)
-            neg_delta_grad += ctx.epsilon * deriv_is_zero
-
-        return x_grad - neg_delta_grad, None, None, None, None, None
+    def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None, None]:
+        factor, xgt0, below_threshold = ctx.saved_tensors
+        dtype = x_grad.dtype
+        too_small_factor = below_threshold.to(dtype) * (xgt0.to(dtype) - 0.5) * (ctx.max_factor * 2.0)
+
+        neg_delta_grad = x_grad.abs() * (factor + too_small_factor)
+
+        return x_grad - neg_delta_grad, None, None, None, None
 
 
 class BasicNorm(torch.nn.Module):
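Previously this function only pushed channels toward a minimum proportion of positive values (the factor term) and used epsilon to handle all-zero gradients; the new min_abs parameter additionally biases gradients on channels whose mean absolute value falls below 0.2, replacing the old zero/epsilon mechanism. Below is a minimal standalone sketch, not part of the commit, of what the new correction term does to a gradient; the tensor shapes and constants are illustrative assumptions:

# Standalone illustration (hypothetical) of the min_abs correction term.
import torch

torch.manual_seed(0)
channel_dim, min_abs, max_factor = 0, 0.2, 0.05

x = 0.05 * torch.randn(4, 8)   # small activations: per-channel mean |x| < min_abs
x_grad = torch.randn_like(x)   # some incoming gradient

sum_dims = [d for d in range(x.ndim) if d != channel_dim]
xgt0 = x > 0
below_threshold = torch.mean(x.abs(), dim=sum_dims, keepdim=True) < min_abs

# Where a channel's mean |x| is below min_abs, this is +max_factor for
# positive x and -max_factor for non-positive x; elsewhere it is 0.
too_small_factor = below_threshold.to(x.dtype) * (xgt0.to(x.dtype) - 0.5) * (max_factor * 2.0)

# Subtracting |grad| * too_small_factor shrinks the gradient on positive
# values and enlarges it on negative ones, so a gradient-descent step
# pushes both away from zero, i.e. grows the channel's magnitude.
adjusted_grad = x_grad - x_grad.abs() * too_small_factor
print(adjusted_grad - x_grad)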
@@ -449,19 +452,17 @@ class DerivBalancer(torch.nn.Module):
     def __init__(self, channel_dim: int,
                  threshold: float = 0.05,
                  max_factor: float = 0.02,
-                 zero: float = 0.02,
-                 epsilon: float = 1.0e-10):
+                 min_abs: float = 0.2):
         super(DerivBalancer, self).__init__()
         self.channel_dim = channel_dim
         self.threshold = threshold
         self.max_factor = max_factor
-        self.zero = zero
-        self.epsilon = epsilon
+        self.min_abs = min_abs
 
     def forward(self, x: Tensor) -> Tensor:
         return DerivBalancerFunction.apply(x, self.channel_dim, self.threshold,
-                                           self.max_factor, self.zero,
-                                           self.epsilon)
+                                           self.max_factor, self.min_abs)
 
 
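Since the forward pass is the identity and only the backward pass is modified, the wrapper can be dropped into a model without changing its outputs. A hedged usage sketch follows; the tensor shape and channel_dim choice are assumptions, not from the commit:

import torch

balancer = DerivBalancer(channel_dim=1, min_abs=0.2)  # threshold/max_factor keep their defaults
x = torch.randn(32, 256, requires_grad=True)          # assumed (batch, channel) layout
y = balancer(x)             # identity in the forward direction
y.sum().backward()          # gradient adjustments happen here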
@@ -505,23 +506,41 @@ def _test_exp_scale_relu():
 
 
 
-def _test_deriv_balancer():
+def _test_deriv_balancer_sign():
     channel_dim = 0
     probs = torch.arange(0, 1, 0.01)
     N = 500
     x = 1.0 * (torch.rand(probs.numel(), N) < probs.unsqueeze(-1))
     x = x.detach()
     x.requires_grad = True
-    m = DerivBalancer(channel_dim=0, threshold=0.05, max_factor=0.2, epsilon=1.0e-10)
+    m = DerivBalancer(channel_dim=0, threshold=0.05, max_factor=0.2, min_abs=0.2)
 
     y_grad = torch.sign(torch.randn(probs.numel(), N))
     y_grad[-1,:] = 0
 
     y = m(x)
     y.backward(gradient=y_grad)
-    print("x = ", x)
-    print("y grad = ", y_grad)
-    print("x grad = ", x.grad)
+    print("_test_deriv_balancer_sign: x = ", x)
+    print("_test_deriv_balancer_sign: y grad = ", y_grad)
+    print("_test_deriv_balancer_sign: x grad = ", x.grad)
+
+
+def _test_deriv_balancer_magnitude():
+    channel_dim = 0
+    magnitudes = torch.arange(0, 1, 0.01)
+    N = 500
+    x = 1.0 * (torch.randn(magnitudes.numel(), N) * magnitudes.unsqueeze(-1))
+    x = x.detach()
+    x.requires_grad = True
+    m = DerivBalancer(channel_dim=0, threshold=0.05, max_factor=0.2, min_abs=0.2)
+
+    y_grad = torch.sign(torch.randn(magnitudes.numel(), N))
+    y_grad[-1,:] = 0
+
+    y = m(x)
+    y.backward(gradient=y_grad)
+    print("_test_deriv_balancer_magnitude: x = ", x)
+    print("_test_deriv_balancer_magnitude: y grad = ", y_grad)
+    print("_test_deriv_balancer_magnitude: x grad = ", x.grad)
 
 
 def _test_basic_norm():
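The new _test_deriv_balancer_magnitude test feeds channels whose average magnitude sweeps from 0 to 1 and only prints the gradients; one could also assert the expected bias directly. A hypothetical extra check, not in the commit, relying on the fact that with all-ones output gradients and tiny activations the factor term is exactly zero and too_small_factor is exactly +/- max_factor:

def _check_magnitude_bias():
    # Hypothetical: every channel's mean |x| is far below min_abs.
    m = DerivBalancer(channel_dim=0, threshold=0.05, max_factor=0.2, min_abs=0.2)
    x = (0.01 * torch.randn(2, 10000)).detach().requires_grad_(True)
    y = m(x)
    y.backward(gradient=torch.ones_like(x))
    # x.grad = 1 - too_small_factor: positive entries get 1 - max_factor,
    # non-positive entries get 1 + max_factor.
    assert torch.allclose(x.grad[x > 0], torch.tensor(0.8))
    assert torch.allclose(x.grad[x <= 0], torch.tensor(1.2))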
@@ -543,7 +562,8 @@ def _test_basic_norm():
 
 
 if __name__ == '__main__':
-    _test_deriv_balancer()
+    _test_deriv_balancer_sign()
+    _test_deriv_balancer_magnitude()
     _test_exp_scale_swish()
     _test_exp_scale_relu()
     _test_basic_norm()
@ -110,7 +110,7 @@ def get_parser():
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--exp-dir",
|
"--exp-dir",
|
||||||
type=str,
|
type=str,
|
||||||
default="transducer_stateless/specaugmod_baseline_randcombine1_expscale3_brelu2swish2_0.1_bnorm2z0.02",
|
default="transducer_stateless/specaugmod_baseline_randcombine1_expscale3_brelu2swish2_0.1_bnorm2ma0.1",
|
||||||
help="""The experiment dir.
|
help="""The experiment dir.
|
||||||
It specifies the directory where all training related
|
It specifies the directory where all training related
|
||||||
files, e.g., checkpoints, log, etc, are saved
|
files, e.g., checkpoints, log, etc, are saved
|
||||||
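The directory suffix changes from bnorm2z0.02 to bnorm2ma0.1, presumably encoding the switch from the zero parameter to a min-abs setting. A hypothetical invocation overriding the default (the script path is an assumption; only the --exp-dir flag is confirmed by the diff):

python transducer_stateless/train.py --exp-dir transducer_stateless/my_experiment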