Fix tests, make SwooshL and SwooshR more efficient in forward pass.

This commit is contained in:
Daniel Povey 2023-04-27 22:35:26 +08:00
parent 55a1abc9da
commit 6c26754628

View File

@ -1286,7 +1286,11 @@ class SwooshL(torch.nn.Module):
if torch.jit.is_scripting(): if torch.jit.is_scripting():
zero = torch.tensor(0.0, dtype=x.dtype, device=x.device) zero = torch.tensor(0.0, dtype=x.dtype, device=x.device)
return torch.logaddexp(zero, x - 4.0) - 0.08 * x - 0.035 return torch.logaddexp(zero, x - 4.0) - 0.08 * x - 0.035
return SwooshLFunction.apply(x) if not x.requires_grad:
return k2.swoosh_l_forward(x)
else:
return k2.swoosh_l(x)
#return SwooshLFunction.apply(x)
class SwooshRFunction(torch.autograd.Function): class SwooshRFunction(torch.autograd.Function):
@ -1294,6 +1298,7 @@ class SwooshRFunction(torch.autograd.Function):
swoosh(x) = log(1 + exp(x-1)) - 0.08*x - 0.313261687 swoosh(x) = log(1 + exp(x-1)) - 0.08*x - 0.313261687
derivatives are between -0.08 and 0.92. derivatives are between -0.08 and 0.92.
""" """
@staticmethod @staticmethod
@ -1325,7 +1330,6 @@ class SwooshRFunction(torch.autograd.Function):
# for self-testing only. # for self-testing only.
assert d_scaled.min() >= 0.0 assert d_scaled.min() >= 0.0
assert d_scaled.max() < 256.0 assert d_scaled.max() < 256.0
d_int = d_scaled.to(torch.uint8) d_int = d_scaled.to(torch.uint8)
ctx.save_for_backward(d_int) ctx.save_for_backward(d_int)
if x.dtype == torch.float16 or torch.is_autocast_enabled(): if x.dtype == torch.float16 or torch.is_autocast_enabled():
@ -1344,12 +1348,16 @@ class SwooshRFunction(torch.autograd.Function):
class SwooshR(torch.nn.Module): class SwooshR(torch.nn.Module):
def forward(self, x: Tensor) -> Tensor: def forward(self, x: Tensor) -> Tensor:
"""Return Swoosh-L activation. """Return Swoosh-R activation.
""" """
if torch.jit.is_scripting(): if torch.jit.is_scripting():
zero = torch.tensor(0.0, dtype=x.dtype, device=x.device) zero = torch.tensor(0.0, dtype=x.dtype, device=x.device)
return torch.logaddexp(zero, x - 1.) - 0.08 * x - 0.313261687 return torch.logaddexp(zero, x - 1.) - 0.08 * x - 0.313261687
return SwooshRFunction.apply(x) if not x.requires_grad:
return k2.swoosh_r_forward(x)
else:
return k2.swoosh_r(x)
# return SwooshRFunction.apply(x)
# simple version of SwooshL that does not redefine the backprop, used in # simple version of SwooshL that does not redefine the backprop, used in
@ -1605,20 +1613,6 @@ def _test_balancer_magnitude():
def _test_basic_norm():
num_channels = 128
m = BasicNorm(num_channels=num_channels, channel_dim=1)
x = torch.randn(500, num_channels)
y = m(x)
assert y.shape == x.shape
x_rms = (x ** 2).mean().sqrt()
y_rms = (y ** 2).mean().sqrt()
print("x rms = ", x_rms)
print("y rms = ", y_rms)
def _test_double_swish_deriv(): def _test_double_swish_deriv():
x = torch.randn(10, 12, dtype=torch.double) * 3.0 x = torch.randn(10, 12, dtype=torch.double) * 3.0
@ -1639,8 +1633,9 @@ def _test_swooshl_deriv():
x.requires_grad = True x.requires_grad = True
m = SwooshL() m = SwooshL()
tol = (1.0 / 255.0) tol = (1.0 / 255.0)
torch.autograd.gradcheck(m, x, atol=tol) torch.autograd.gradcheck(m, x, atol=tol, eps=0.01)
# for self-test. # for self-test.
x = torch.randn(1000, 1000, dtype=torch.double) * 3.0 x = torch.randn(1000, 1000, dtype=torch.double) * 3.0
@ -1653,7 +1648,7 @@ def _test_swooshr_deriv():
m = SwooshR() m = SwooshR()
tol = (1.0 / 255.0) tol = (1.0 / 255.0)
torch.autograd.gradcheck(m, x, atol=tol) torch.autograd.gradcheck(m, x, atol=tol, eps=0.01)
# for self-test. # for self-test.
x = torch.randn(1000, 1000, dtype=torch.double) * 3.0 x = torch.randn(1000, 1000, dtype=torch.double) * 3.0
@ -1733,8 +1728,10 @@ def _test_activation_dropout_and_linear():
for bias in [True, False]: for bias in [True, False]:
# actually we don't test for dropout_p != 0.0 because forward functions will give # actually we don't test for dropout_p != 0.0 because forward functions will give
# different answers. This is because # different answers. This is because we are using the k2 implementation of
for dropout_p in [0.0, 0.1]: # swoosh_l an swoosh_r inside SwooshL() and SwooshR(), and they call randn()
# internally, messing up the random state.
for dropout_p in [0.0]:
for activation in ['SwooshL', 'SwooshR']: for activation in ['SwooshL', 'SwooshR']:
m1 = nn.Sequential(SwooshL() if activation == 'SwooshL' else SwooshR(), m1 = nn.Sequential(SwooshL() if activation == 'SwooshL' else SwooshR(),
Dropout3(p=dropout_p, shared_dim=-1), Dropout3(p=dropout_p, shared_dim=-1),
@ -1796,8 +1793,7 @@ if __name__ == "__main__":
_test_whiten() _test_whiten()
_test_balancer_sign() _test_balancer_sign()
_test_balancer_magnitude() _test_balancer_magnitude()
_test_basic_norm()
_test_double_swish_deriv()
_test_swooshr_deriv()
_test_swooshl_deriv() _test_swooshl_deriv()
_test_swooshr_deriv()
_test_activation_dropout_and_linear() _test_activation_dropout_and_linear()
_test_double_swish_deriv()