From 6c26754628d885013b3d08e5b646b57e1a43f64f Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Thu, 27 Apr 2023 22:35:26 +0800
Subject: [PATCH] Fix tests, make SwooshL and SwooshR more efficient in forward pass.

---
 .../pruned_transducer_stateless7/scaling.py | 46 +++++++++----------
 1 file changed, 21 insertions(+), 25 deletions(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
index 6e6edf5f4..9b6c8880b 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
@@ -1286,7 +1286,11 @@ class SwooshL(torch.nn.Module):
         if torch.jit.is_scripting():
             zero = torch.tensor(0.0, dtype=x.dtype, device=x.device)
             return torch.logaddexp(zero, x - 4.0) - 0.08 * x - 0.035
-        return SwooshLFunction.apply(x)
+        if not x.requires_grad:
+            return k2.swoosh_l_forward(x)
+        else:
+            return k2.swoosh_l(x)
+        #return SwooshLFunction.apply(x)
 
 
 class SwooshRFunction(torch.autograd.Function):
@@ -1294,6 +1298,7 @@ class SwooshRFunction(torch.autograd.Function):
     swoosh(x) = log(1 + exp(x-1)) - 0.08*x - 0.313261687
 
     derivatives are between -0.08 and 0.92.
+
     """
 
     @staticmethod
@@ -1325,7 +1330,6 @@ class SwooshRFunction(torch.autograd.Function):
                     # for self-testing only.
                     assert d_scaled.min() >= 0.0
                     assert d_scaled.max() < 256.0
-
                 d_int = d_scaled.to(torch.uint8)
                 ctx.save_for_backward(d_int)
                 if x.dtype == torch.float16 or torch.is_autocast_enabled():
@@ -1344,12 +1348,16 @@ class SwooshRFunction(torch.autograd.Function):
 
 class SwooshR(torch.nn.Module):
     def forward(self, x: Tensor) -> Tensor:
-        """Return Swoosh-L activation.
+        """Return Swoosh-R activation.
         """
         if torch.jit.is_scripting():
             zero = torch.tensor(0.0, dtype=x.dtype, device=x.device)
             return torch.logaddexp(zero, x - 1.) - 0.08 * x - 0.313261687
-        return SwooshRFunction.apply(x)
+        if not x.requires_grad:
+            return k2.swoosh_r_forward(x)
+        else:
+            return k2.swoosh_r(x)
+        # return SwooshRFunction.apply(x)
 
 
 # simple version of SwooshL that does not redefine the backprop, used in
@@ -1605,20 +1613,6 @@ def _test_balancer_magnitude():
 
 
 
-def _test_basic_norm():
-    num_channels = 128
-    m = BasicNorm(num_channels=num_channels, channel_dim=1)
-
-    x = torch.randn(500, num_channels)
-
-    y = m(x)
-
-    assert y.shape == x.shape
-    x_rms = (x ** 2).mean().sqrt()
-    y_rms = (y ** 2).mean().sqrt()
-    print("x rms = ", x_rms)
-    print("y rms = ", y_rms)
-
 
 def _test_double_swish_deriv():
     x = torch.randn(10, 12, dtype=torch.double) * 3.0
@@ -1639,8 +1633,9 @@ def _test_swooshl_deriv():
     x.requires_grad = True
     m = SwooshL()
 
+
     tol = (1.0 / 255.0)
-    torch.autograd.gradcheck(m, x, atol=tol)
+    torch.autograd.gradcheck(m, x, atol=tol, eps=0.01)
 
     # for self-test.
     x = torch.randn(1000, 1000, dtype=torch.double) * 3.0
@@ -1653,7 +1648,7 @@ def _test_swooshr_deriv():
     m = SwooshR()
 
     tol = (1.0 / 255.0)
-    torch.autograd.gradcheck(m, x, atol=tol)
+    torch.autograd.gradcheck(m, x, atol=tol, eps=0.01)
 
     # for self-test.
     x = torch.randn(1000, 1000, dtype=torch.double) * 3.0
@@ -1733,8 +1728,10 @@ def _test_activation_dropout_and_linear():
 
     for bias in [True, False]:
         # actually we don't test for dropout_p != 0.0 because forward functions will give
-        # different answers. This is because
-        for dropout_p in [0.0, 0.1]:
+        # different answers. This is because we are using the k2 implementation of
+        # swoosh_l and swoosh_r inside SwooshL() and SwooshR(), and they call randn()
+        # internally, messing up the random state.
+        for dropout_p in [0.0]:
             for activation in ['SwooshL', 'SwooshR']:
                 m1 = nn.Sequential(SwooshL() if activation == 'SwooshL' else SwooshR(),
                                    Dropout3(p=dropout_p, shared_dim=-1),
@@ -1796,8 +1793,7 @@ if __name__ == "__main__":
     _test_whiten()
     _test_balancer_sign()
     _test_balancer_magnitude()
-    _test_basic_norm()
-    _test_double_swish_deriv()
-    _test_swooshr_deriv()
     _test_swooshl_deriv()
+    _test_swooshr_deriv()
     _test_activation_dropout_and_linear()
+    _test_double_swish_deriv()
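
Note (addendum, not part of the patch): the diff above makes SwooshL.forward() and SwooshR.forward() dispatch to k2 kernels, using k2.swoosh_l_forward() / k2.swoosh_r_forward() when no gradient is needed and k2.swoosh_l() / k2.swoosh_r() otherwise. The sketch below is a minimal, illustrative check of that dispatch against the pure-PyTorch formulas from the torch.jit.is_scripting() branches. It assumes a k2 build that exposes those four functions; the printed differences are only a rough guide, since (like the SwooshLFunction/SwooshRFunction code being replaced) the backward pass is only expected to be accurate to about 1/255, the atol used in _test_swooshl_deriv() and _test_swooshr_deriv().

# Illustrative sketch, not part of the patch.  Assumes a k2 build exposing
# swoosh_l, swoosh_l_forward, swoosh_r and swoosh_r_forward, as called above.
import torch
import k2


def swoosh_l_ref(x: torch.Tensor) -> torch.Tensor:
    # Swoosh-L: log(1 + exp(x - 4)) - 0.08*x - 0.035, the formula used by the
    # torch.jit.is_scripting() branch of SwooshL.forward().
    zero = torch.tensor(0.0, dtype=x.dtype, device=x.device)
    return torch.logaddexp(zero, x - 4.0) - 0.08 * x - 0.035


def swoosh_r_ref(x: torch.Tensor) -> torch.Tensor:
    # Swoosh-R: log(1 + exp(x - 1)) - 0.08*x - 0.313261687, as in SwooshR.forward().
    zero = torch.tensor(0.0, dtype=x.dtype, device=x.device)
    return torch.logaddexp(zero, x - 1.0) - 0.08 * x - 0.313261687


def main():
    x = torch.randn(1000) * 3.0

    # Forward-only path: what SwooshL()/SwooshR() call when x.requires_grad is False.
    print("SwooshL forward max diff:",
          (k2.swoosh_l_forward(x) - swoosh_l_ref(x)).abs().max().item())
    print("SwooshR forward max diff:",
          (k2.swoosh_r_forward(x) - swoosh_r_ref(x)).abs().max().item())

    # Autograd path: what SwooshL() calls when x.requires_grad is True.  Per the
    # comment in the patch, the k2 kernels consume random numbers internally, and
    # gradients are only expected to match the reference to roughly 1/255.
    xk = x.clone().requires_grad_(True)
    k2.swoosh_l(xk).sum().backward()
    xr = x.clone().requires_grad_(True)
    swoosh_l_ref(xr).sum().backward()
    print("SwooshL grad max diff:", (xk.grad - xr.grad).abs().max().item())


if __name__ == "__main__":
    main()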