From 7b1f09307753d69cd9ee5594dec89d8dfe2bbae6 Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Sun, 4 Dec 2022 19:18:16 +0800
Subject: [PATCH] Use Swoosh-R in the Conv and Swoosh-L in the feedforward.

---
 .../pruned_transducer_stateless7/scaling.py   | 91 +++++++++++++++++--
 .../pruned_transducer_stateless7/zipformer.py |  9 +-
 2 files changed, 89 insertions(+), 11 deletions(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
index 484eb24ef..32bfab990 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
@@ -1214,7 +1214,7 @@ class TanSwish(torch.nn.Module):
 
 
-class SwooshFunction(torch.autograd.Function):
+class SwooshLFunction(torch.autograd.Function):
     """
       swoosh(x) = log(1 + exp(x-1)) - 0.08*x - 0.313261687
 
@@ -1267,14 +1267,77 @@ class SwooshFunction(torch.autograd.Function):
         return (y_grad * d)
 
 
-class Swoosh(torch.nn.Module):
+class SwooshL(torch.nn.Module):
     def forward(self, x: Tensor) -> Tensor:
-        """Return tan-swish activation function which is tanh(x) sigmoid(x-1)
+        """Return Swoosh-L activation.
         """
         if torch.jit.is_scripting():
             zero = torch.tensor(0.0, dtype=x.dtype, device=x.device)
             return torch.logaddexp(zero, x - 4.0) - 0.1 * x - 0.035
-        return SwooshFunction.apply(x)
+        return SwooshLFunction.apply(x)
+
+
+class SwooshRFunction(torch.autograd.Function):
+    """
+      swoosh(x) = log(1 + exp(x-1)) - 0.08*x - 0.313261687
+
+      derivatives are between -0.08 and 0.92.
+    """
+
+    @staticmethod
+    def forward(ctx, x: Tensor) -> Tensor:
+        requires_grad = x.requires_grad
+        x_dtype = x.dtype
+
+        if x.dtype == torch.float16:
+            x = x.to(torch.float32)
+
+        zero = torch.tensor(0.0, dtype=x.dtype, device=x.device)
+
+        with torch.cuda.amp.autocast(enabled=False):
+            with torch.enable_grad():
+                x = x.detach()
+                x.requires_grad = True
+                y = torch.logaddexp(zero, x - 1.) - 0.08 * x - 0.313261687
+
+                if not requires_grad:
+                    return y
+                y.backward(gradient = torch.ones_like(y))
+
+                grad = x.grad
+                floor = -0.08
+                ceil = 0.925
+
+                d_scaled = ((grad - floor) * (255.0 / (ceil - floor)) + torch.rand_like(grad))
+                if __name__ == "__main__":
+                    # for self-testing only.
+                    assert d_scaled.min() >= 0.0
+                    assert d_scaled.max() < 256.0
+
+                d_int = d_scaled.to(torch.uint8)
+                ctx.save_for_backward(d_int)
+                if x.dtype == torch.float16 or torch.is_autocast_enabled():
+                    y = y.to(torch.float16)
+                return y
+
+    @staticmethod
+    def backward(ctx, y_grad: Tensor) -> Tensor:
+        d, = ctx.saved_tensors
+        # the same constants as used in forward pass.
+        floor = -0.08
+        ceil = 0.925
+        d = (d * ((ceil - floor) / 255.0) + floor)
+        return (y_grad * d)
+
+
+class SwooshR(torch.nn.Module):
+    def forward(self, x: Tensor) -> Tensor:
+        """Return Swoosh-R activation.
+        """
+        if torch.jit.is_scripting():
+            zero = torch.tensor(0.0, dtype=x.dtype, device=x.device)
+            return torch.logaddexp(zero, x - 1.) - 0.08 * x - 0.313261687
+        return SwooshRFunction.apply(x)
 
 
 
@@ -1434,10 +1497,23 @@ def _test_tan_swish_deriv():
     x.requires_grad = True
     y = m(x)
 
-def _test_swoosh_deriv():
+def _test_swooshl_deriv():
     x = torch.randn(10, 12, dtype=torch.double) * 3.0
     x.requires_grad = True
-    m = Swoosh()
+    m = SwooshL()
+
+    tol = (1.0 / 255.0)
+    torch.autograd.gradcheck(m, x, atol=tol)
+
+    # for self-test.
+    x = torch.randn(1000, 1000, dtype=torch.double) * 3.0
+    x.requires_grad = True
+    y = m(x)
+
+def _test_swooshr_deriv():
+    x = torch.randn(10, 12, dtype=torch.double) * 3.0
+    x.requires_grad = True
+    m = SwooshR()
 
     tol = (1.0 / 255.0)
     torch.autograd.gradcheck(m, x, atol=tol)
@@ -1474,4 +1550,5 @@ if __name__ == "__main__":
     _test_basic_norm()
     _test_double_swish_deriv()
     _test_tan_swish_deriv()
-    _test_swoosh_deriv()
+    _test_swooshr_deriv()
+    _test_swooshl_deriv()
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
index 6eb8182c4..038da0136 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
@@ -29,7 +29,8 @@ from scaling import (
     BasicNorm,
     MaxEig,
     DoubleSwish,
-    Swoosh,
+    SwooshL,
+    SwooshR,
     TanSwish,
     ScaledConv1d,
     ScaledLinear,  # not as in other dirs.. just scales down initial parameter values.
@@ -1426,7 +1427,7 @@ class FeedforwardModule(nn.Module):
                                            min_abs=1.0, max_abs=5.0,
                                            min_prob=0.25)
-        self.activation = Swoosh()
+        self.activation = SwooshL()
         self.dropout = nn.Dropout(dropout)
         self.out_proj = LinearWithAuxLoss(feedforward_dim, embed_dim,
                                           initial_scale=0.01,
@@ -1601,11 +1602,11 @@ class ConvolutionModule(nn.Module):
             channels, channel_dim=1,
             min_positive=ScheduledFloat((0.0, 0.1), (8000.0, 0.05)),
             max_positive=1.0,
-            min_abs=0.75,
+            min_abs=ScheduledFloat((0.0, 0.2), (20000.0, 1.0)),
             max_abs=10.0,
         )
 
-        self.activation = nn.Tanh()
+        self.activation = SwooshR()
 
         self.whiten = Whiten(num_groups=1,
                              whitening_limit=_whitening_schedule(7.5),
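
Note (not part of the patch): a minimal standalone sketch of the two activations this change wires in, using only the closed-form expressions from the torch.jit.is_scripting() fallbacks above. The helper names swoosh_l / swoosh_r are illustrative, not from the repository; the derivative check mirrors the "-0.08 and 0.92" bound quoted in the SwooshRFunction docstring and the [floor, ceil] range its backward pass uses for the uint8-quantized gradient.

import torch

def swoosh_r(x: torch.Tensor) -> torch.Tensor:
    # Swoosh-R as in the SwooshR scripting fallback:
    # log(1 + exp(x - 1)) - 0.08*x - 0.313261687 (approximately zero at x = 0).
    zero = torch.zeros((), dtype=x.dtype, device=x.device)
    return torch.logaddexp(zero, x - 1.0) - 0.08 * x - 0.313261687

def swoosh_l(x: torch.Tensor) -> torch.Tensor:
    # Swoosh-L as written in the SwooshL scripting fallback of this patch:
    # log(1 + exp(x - 4)) - 0.1*x - 0.035.
    zero = torch.zeros((), dtype=x.dtype, device=x.device)
    return torch.logaddexp(zero, x - 4.0) - 0.1 * x - 0.035

if __name__ == "__main__":
    x = torch.linspace(-8.0, 8.0, steps=1001, requires_grad=True)
    (g,) = torch.autograd.grad(swoosh_r(x).sum(), x)
    # SwooshRFunction stores this derivative as uint8 over [-0.08, 0.925];
    # the true derivative sigmoid(x - 1) - 0.08 stays inside that range.
    print(float(g.min()), float(g.max()))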