mirror of https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
Fix tests, make SwooshL and SwooshR more efficient in forward pass.
This commit is contained in:
parent 55a1abc9da
commit 6c26754628
@@ -1286,7 +1286,11 @@ class SwooshL(torch.nn.Module):
         if torch.jit.is_scripting():
             zero = torch.tensor(0.0, dtype=x.dtype, device=x.device)
             return torch.logaddexp(zero, x - 4.0) - 0.08 * x - 0.035
-        return SwooshLFunction.apply(x)
+        if not x.requires_grad:
+            return k2.swoosh_l_forward(x)
+        else:
+            return k2.swoosh_l(x)
+        #return SwooshLFunction.apply(x)
 
 
 class SwooshRFunction(torch.autograd.Function):
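Aside: a condensed, self-contained sketch of the dispatch this hunk introduces. It assumes, as the diff itself implies, that k2 exposes swoosh_l_forward (forward-only) and swoosh_l (autograd-enabled); the function name swoosh_l_dispatch and the pure-PyTorch fallback are only illustrative, the fallback simply repeating the scripted branch above.

import torch

def swoosh_l_dispatch(x: torch.Tensor) -> torch.Tensor:
    try:
        import k2  # assumed available in the icefall setup; optional here
    except ImportError:
        k2 = None
    if k2 is not None:
        # forward-only kernel when no gradient is needed, autograd kernel otherwise
        return k2.swoosh_l_forward(x) if not x.requires_grad else k2.swoosh_l(x)
    # illustrative fallback: same formula as the torch.jit.is_scripting() branch
    zero = torch.zeros((), dtype=x.dtype, device=x.device)
    return torch.logaddexp(zero, x - 4.0) - 0.08 * x - 0.035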
@@ -1294,6 +1298,7 @@ class SwooshRFunction(torch.autograd.Function):
     swoosh(x) = log(1 + exp(x-1)) - 0.08*x - 0.313261687
 
     derivatives are between -0.08 and 0.92.
 
     """
+
     @staticmethod
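A quick sanity check of the docstring claim that the Swoosh-R derivative stays between -0.08 and 0.92: differentiating log(1 + exp(x-1)) - 0.08*x gives sigmoid(x-1) - 0.08, and sigmoid is bounded in (0, 1). A minimal plain-PyTorch check (not using the k2 kernels):

import torch

x = torch.linspace(-20.0, 20.0, steps=2001, dtype=torch.double, requires_grad=True)
zero = torch.zeros((), dtype=x.dtype)
y = torch.logaddexp(zero, x - 1.0) - 0.08 * x - 0.313261687
(g,) = torch.autograd.grad(y.sum(), x)
# the derivative equals sigmoid(x - 1) - 0.08, so it stays inside (-0.08, 0.92)
assert g.min() > -0.08 and g.max() < 0.92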
@@ -1325,7 +1330,6 @@ class SwooshRFunction(torch.autograd.Function):
             # for self-testing only.
             assert d_scaled.min() >= 0.0
             assert d_scaled.max() < 256.0
-
         d_int = d_scaled.to(torch.uint8)
         ctx.save_for_backward(d_int)
         if x.dtype == torch.float16 or torch.is_autocast_enabled():
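The context lines here show the memory trick in SwooshRFunction: the derivative is rescaled into [0, 256), stored as uint8 via save_for_backward, and (outside this hunk) expanded back in backward(). A hedged, self-contained sketch of that pattern follows; the class name QuantGradSketch and the FLOOR/CEIL bounds are illustrative stand-ins, not the exact constants used in scaling.py.

import torch
from torch import Tensor

class QuantGradSketch(torch.autograd.Function):
    FLOOR, CEIL = -0.08, 0.92  # assumed derivative bounds, for illustration only

    @staticmethod
    def forward(ctx, x: Tensor) -> Tensor:
        # compute y and its elementwise derivative without building an outer graph
        with torch.enable_grad():
            xd = x.detach().requires_grad_(True)
            zero = torch.zeros((), dtype=xd.dtype, device=xd.device)
            y = torch.logaddexp(zero, xd - 1.0) - 0.08 * xd - 0.313261687
            (deriv,) = torch.autograd.grad(y.sum(), xd)
        f, c = QuantGradSketch.FLOOR, QuantGradSketch.CEIL
        # rescale the derivative into [0, 256) and keep only one byte per element
        d_scaled = (deriv - f) * (255.0 / (c - f))
        d_int = d_scaled.clamp(0.0, 255.0).to(torch.uint8)
        ctx.save_for_backward(d_int)
        return y.detach()

    @staticmethod
    def backward(ctx, y_grad: Tensor) -> Tensor:
        (d_int,) = ctx.saved_tensors
        f, c = QuantGradSketch.FLOOR, QuantGradSketch.CEIL
        # undo the quantization (accurate to about 1/255) and apply the chain rule
        d = d_int.to(y_grad.dtype) * ((c - f) / 255.0) + f
        return y_grad * d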
@@ -1344,12 +1348,16 @@ class SwooshRFunction(torch.autograd.Function):
 
 class SwooshR(torch.nn.Module):
     def forward(self, x: Tensor) -> Tensor:
-        """Return Swoosh-L activation.
+        """Return Swoosh-R activation.
         """
         if torch.jit.is_scripting():
             zero = torch.tensor(0.0, dtype=x.dtype, device=x.device)
             return torch.logaddexp(zero, x - 1.) - 0.08 * x - 0.313261687
-        return SwooshRFunction.apply(x)
+        if not x.requires_grad:
+            return k2.swoosh_r_forward(x)
+        else:
+            return k2.swoosh_r(x)
+        # return SwooshRFunction.apply(x)
 
 
 # simple version of SwooshL that does not redefine the backprop, used in
@@ -1605,20 +1613,6 @@ def _test_balancer_magnitude():
 
 
-
-def _test_basic_norm():
-    num_channels = 128
-    m = BasicNorm(num_channels=num_channels, channel_dim=1)
-
-    x = torch.randn(500, num_channels)
-
-    y = m(x)
-
-    assert y.shape == x.shape
-    x_rms = (x ** 2).mean().sqrt()
-    y_rms = (y ** 2).mean().sqrt()
-    print("x rms = ", x_rms)
-    print("y rms = ", y_rms)
 
 
 def _test_double_swish_deriv():
     x = torch.randn(10, 12, dtype=torch.double) * 3.0
@@ -1639,8 +1633,9 @@ def _test_swooshl_deriv():
     x.requires_grad = True
     m = SwooshL()
 
+
     tol = (1.0 / 255.0)
-    torch.autograd.gradcheck(m, x, atol=tol)
+    torch.autograd.gradcheck(m, x, atol=tol, eps=0.01)
 
     # for self-test.
     x = torch.randn(1000, 1000, dtype=torch.double) * 3.0
@@ -1653,7 +1648,7 @@ def _test_swooshr_deriv():
     m = SwooshR()
 
     tol = (1.0 / 255.0)
-    torch.autograd.gradcheck(m, x, atol=tol)
+    torch.autograd.gradcheck(m, x, atol=tol, eps=0.01)
 
     # for self-test.
     x = torch.randn(1000, 1000, dtype=torch.double) * 3.0
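On the eps=0.01 added to the gradcheck calls in these two hunks: atol is already 1/255, matching the uint8 quantization of the stored derivative, and a coarser finite-difference step plausibly keeps the numerical gradient from amplifying small rounding error in the k2-backed forward; that reading is inferred, not stated in the commit. A standalone usage sketch with a plain-PyTorch stand-in module (SwooshRRef is illustrative, not the repo's class):

import torch

class SwooshRRef(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        zero = torch.zeros((), dtype=x.dtype, device=x.device)
        return torch.logaddexp(zero, x - 1.0) - 0.08 * x - 0.313261687

x = torch.randn(10, 12, dtype=torch.double) * 3.0
x.requires_grad = True
# atol matches the 1/255 derivative quantization; eps=0.01 is the probe step size
torch.autograd.gradcheck(SwooshRRef(), x, atol=1.0 / 255.0, eps=0.01)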
@@ -1733,8 +1728,10 @@ def _test_activation_dropout_and_linear():
 
     for bias in [True, False]:
         # actually we don't test for dropout_p != 0.0 because forward functions will give
-        # different answers. This is because
-        for dropout_p in [0.0, 0.1]:
+        # different answers. This is because we are using the k2 implementation of
+        # swoosh_l and swoosh_r inside SwooshL() and SwooshR(), and they call randn()
+        # internally, messing up the random state.
+        for dropout_p in [0.0]:
             for activation in ['SwooshL', 'SwooshR']:
                 m1 = nn.Sequential(SwooshL() if activation == 'SwooshL' else SwooshR(),
                                    Dropout3(p=dropout_p, shared_dim=-1),
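The reworded comment above explains why dropout_p is now restricted to 0.0: the k2 swoosh kernels draw random numbers internally, so a reference implementation run next to them sees a different RNG state and therefore a different dropout mask. A minimal plain-PyTorch illustration of that effect (the extra randn() stands in for the kernel's internal draw):

import torch

x = torch.ones(8)

torch.manual_seed(0)
a = torch.nn.functional.dropout(x, p=0.1, training=True)

torch.manual_seed(0)
_ = torch.randn(8)  # extra draw, like an activation that calls randn() internally
b = torch.nn.functional.dropout(x, p=0.1, training=True)

print(torch.equal(a, b))  # typically False: the dropout masks differ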
@@ -1796,8 +1793,7 @@ if __name__ == "__main__":
     _test_whiten()
     _test_balancer_sign()
     _test_balancer_magnitude()
-    _test_basic_norm()
-    _test_double_swish_deriv()
-    _test_swooshr_deriv()
     _test_swooshl_deriv()
+    _test_swooshr_deriv()
     _test_activation_dropout_and_linear()
+    _test_double_swish_deriv()