From ec10573edcfd1aa71c2fd1d6e1bfa62a4a760bc0 Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Fri, 2 Dec 2022 16:34:53 +0800
Subject: [PATCH] First version of swoosh

---
 .../pruned_transducer_stateless7/scaling.py | 80 +++++++++++++++++++
 1 file changed, 80 insertions(+)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
index 54c597ebc..e594e4e4d 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
@@ -1212,6 +1212,72 @@ class TanSwish(torch.nn.Module):
         return TanSwishFunction.apply(x)
 
 
+
+class SwooshFunction(torch.autograd.Function):
+    """
+    swoosh(x) = log(1 + exp(x-4)) - 0.055*x - 0.15
+
+    derivatives are between -0.055 and 1-0.055.
+    """
+
+    @staticmethod
+    def forward(ctx, x: Tensor) -> Tensor:
+        requires_grad = x.requires_grad
+        x_dtype = x.dtype
+
+        if x.dtype == torch.float16:
+            x = x.to(torch.float32)
+
+        one = torch.tensor(1.0, dtype=x.dtype, device=x.device)
+
+        with torch.cuda.amp.autocast(enabled=False):
+            with torch.enable_grad():
+                x = x.detach()
+                x.requires_grad = True
+                y = torch.logaddexp(one, x - 4) - 0.055 * x - 0.15
+
+                if not requires_grad:
+                    return y
+                y.backward(gradient=torch.ones_like(y))
+
+                grad = x.grad
+                floor = -0.055
+                ceil = 0.946  # real ceil would be 0.945, give it extra room for roundoff.
+
+                d_scaled = ((grad - floor) * (255.0 / (ceil - floor)) + torch.rand_like(grad))
+                if __name__ == "__main__":
+                    # for self-testing only.
+                    assert d_scaled.min() >= 0.0
+                    assert d_scaled.max() < 256.0
+
+                d_int = d_scaled.to(torch.uint8)
+                ctx.save_for_backward(d_int)
+                if x_dtype == torch.float16 or torch.is_autocast_enabled():
+                    y = y.to(torch.float16)
+                return y
+
+    @staticmethod
+    def backward(ctx, y_grad: Tensor) -> Tensor:
+        d, = ctx.saved_tensors
+        # the same constants as used in the forward pass.
+        floor = -0.055
+        ceil = 0.946
+        d = (d * ((ceil - floor) / 255.0) + floor)
+        return (y_grad * d)
+
+
+class Swoosh(torch.nn.Module):
+    def forward(self, x: Tensor) -> Tensor:
+        """Return the Swoosh activation, log(1 + exp(x-4)) - 0.055*x - 0.15
+        """
+        if torch.jit.is_scripting():
+            one = torch.tensor(1.0, dtype=x.dtype, device=x.device)
+            return torch.logaddexp(one, x - 4) - 0.055 * x - 0.15
+        return SwooshFunction.apply(x)
+
+
+
 def _test_max_eig():
     for proportion in [0.1, 0.5, 10.0]:
         logging.info(f"proportion = {proportion}")
@@ -1368,6 +1434,19 @@ def _test_tan_swish_deriv():
     x.requires_grad = True
     y = m(x)
 
+def _test_swoosh_deriv():
+    x = torch.randn(10, 12, dtype=torch.double) * 3.0
+    x.requires_grad = True
+    m = Swoosh()
+
+    tol = (1.0 / 255.0)
+    torch.autograd.gradcheck(m, x, atol=tol)
+
+    # for self-test.
+    x = torch.randn(1000, 1000, dtype=torch.double) * 3.0
+    x.requires_grad = True
+    y = m(x)
+
 
 
 def _test_softmax():
@@ -1395,3 +1474,4 @@ if __name__ == "__main__":
     _test_basic_norm()
     _test_double_swish_deriv()
     _test_tan_swish_deriv()
+    _test_swoosh_deriv()
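
Note (not part of the patch): the forward pass above saves memory by quantizing the derivative of swoosh(x) = log(1 + exp(x-4)) - 0.055*x - 0.15, which is sigmoid(x-4) - 0.055 and lies in (-0.055, 0.945), into a uint8 tensor. The sketch below, assuming nothing beyond plain PyTorch, checks that this quantization is accurate to roughly 1/255, the same atol used by _test_swoosh_deriv(). The helper names swoosh_deriv_exact and quantize_deriv are illustrative only and do not exist in scaling.py.

    # Sketch: accuracy of the uint8 derivative storage used by SwooshFunction.
    import torch

    def swoosh_deriv_exact(x: torch.Tensor) -> torch.Tensor:
        # d/dx log(1 + exp(x - 4)) = sigmoid(x - 4)
        return torch.sigmoid(x - 4) - 0.055

    def quantize_deriv(d: torch.Tensor, floor: float = -0.055, ceil: float = 0.946) -> torch.Tensor:
        # Same scheme as SwooshFunction.forward()/backward(): scale the derivative
        # to [0, 255], add uniform noise so the truncation to uint8 rounds without
        # bias, then map the stored byte back to a float derivative.
        d_scaled = (d - floor) * (255.0 / (ceil - floor)) + torch.rand_like(d)
        d_int = d_scaled.to(torch.uint8)
        return d_int * ((ceil - floor) / 255.0) + floor

    if __name__ == "__main__":
        x = torch.randn(1000) * 3.0
        d = swoosh_deriv_exact(x)
        d_hat = quantize_deriv(d)
        max_err = (d - d_hat).abs().max().item()
        print(f"max quantization error: {max_err:.5f}")  # roughly 1/255 at worst
        assert max_err < (0.946 + 0.055) / 255.0 + 1e-6

The torch.rand_like() term makes the truncation a dithered (stochastic) rounding, so the dequantized derivative is unbiased on average and its per-element error stays below one quantization step, (ceil - floor) / 255.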