diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
index ed3784a78..39b08e169 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
@@ -928,7 +928,6 @@ class DoubleSwishFunction(torch.autograd.Function):
         d = (d * ((ceil - floor) / 255.0) + floor)
         return (y_grad * d)
 
-
 class DoubleSwish(torch.nn.Module):
     def forward(self, x: Tensor) -> Tensor:
         """Return double-swish activation function which is an approximation to Swish(Swish(x)),
@@ -939,6 +938,68 @@ class DoubleSwish(torch.nn.Module):
         return DoubleSwishFunction.apply(x)
 
 
+class TanSwishFunction(torch.autograd.Function):
+    """
+    tan_swish(x) = tanh(x) * torch.sigmoid(x-1)
+
+
+    entering: d/dx(tanh(x) * sigmoid(x-1))
+    into wolfram alpha, I see that the range of this function is
+    -0.0498087 <= y <= 0.417894
+    let's widen it slightly (as we don't know how this was rounded):
+    -0.0498088 <= y <= 0.417895
+    """
+
+    @staticmethod
+    def forward(ctx, x: Tensor) -> Tensor:
+        requires_grad = x.requires_grad
+        if not requires_grad:
+            return torch.tanh(x) * torch.sigmoid(x - 1.0)
+
+        x_dtype = x.dtype
+        if x.dtype == torch.float16:
+            x = x.to(torch.float32)
+
+        with torch.cuda.amp.autocast(enabled=False):
+            with torch.enable_grad():
+                x = x.detach()
+                x.requires_grad = True
+                y = torch.tanh(x) * torch.sigmoid(x - 1.0)
+                y.backward(gradient=torch.ones_like(y))
+                grad = x.grad
+                floor = -0.0498088
+                ceil = 0.417895
+                d_scaled = (grad - floor) * (255.0 / (ceil - floor)) + torch.rand_like(grad)
+                if __name__ == "__main__":
+                    # for self-testing only.
+                    assert d_scaled.min() >= 0.0
+                    assert d_scaled.max() < 256.0
+
+                d_int = d_scaled.to(torch.uint8)
+                ctx.save_for_backward(d_int)
+                if x_dtype == torch.float16 or torch.is_autocast_enabled():
+                    y = y.to(torch.float16)
+                return y
+
+    @staticmethod
+    def backward(ctx, y_grad: Tensor) -> Tensor:
+        d, = ctx.saved_tensors
+        # the same constants as used in forward pass.
+        floor = -0.0498088
+        ceil = 0.417895
+        d = (d * ((ceil - floor) / 255.0) + floor)
+        return (y_grad * d)
+
+
+class TanSwish(torch.nn.Module):
+    def forward(self, x: Tensor) -> Tensor:
+        """Return tan-swish activation function, which is tanh(x) * sigmoid(x-1).
+        """
+        if torch.jit.is_scripting():
+            return x.tanh() * torch.sigmoid(x - 1.0)
+        return TanSwishFunction.apply(x)
+
+
 class ScheduledFloat(torch.nn.Module):
     """
     This object is a torch.nn.Module only because we want it to show up in [top_level module].modules();
@@ -1147,6 +1208,20 @@ def _test_double_swish_deriv():
     x.requires_grad = True
     y = m(x)
 
+def _test_tan_swish_deriv():
+    x = torch.randn(10, 12, dtype=torch.double) * 3.0
+    x.requires_grad = True
+    m = TanSwish()
+
+    # atol is one quantization step, using the same floor/ceil as TanSwishFunction.
+    tol = (0.417895 - (-0.0498088)) / 255.0
+    torch.autograd.gradcheck(m, x, atol=tol)
+
+    # for self-test.
+    x = torch.randn(1000, 1000, dtype=torch.double) * 3.0
+    x.requires_grad = True
+    y = m(x)
+
 
 
 def _test_softmax():
@@ -1173,3 +1248,4 @@ if __name__ == "__main__":
     _test_activation_balancer_magnitude()
     _test_basic_norm()
     _test_double_swish_deriv()
+    _test_tan_swish_deriv()
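
TanSwishFunction reuses the memory-saving trick from DoubleSwishFunction above: rather than keeping the full-precision derivative around for the backward pass, it affine-scales the derivative from [floor, ceil] into [0, 255], adds uniform dither so the truncation is unbiased in expectation, and stores one uint8 per element. After dequantization the worst-case error is one quantization step, (ceil - floor) / 255 ~= 0.0018. A self-contained sketch of that round trip, assuming only PyTorch (the variable names below are illustrative, not from the patch):

import torch

# bounds on d/dx(tanh(x) * sigmoid(x - 1)), as in the patch
floor, ceil = -0.0498088, 0.417895

x = torch.randn(10000)
s = torch.sigmoid(x - 1.0)
t = torch.tanh(x)
d_exact = (1 - t ** 2) * s + t * s * (1 - s)  # exact derivative

# quantize: scale to [0, 255], add uniform dither, truncate to uint8
d_scaled = (d_exact - floor) * (255.0 / (ceil - floor)) + torch.rand_like(d_exact)
d_int = d_scaled.to(torch.uint8)

# dequantize, as in TanSwishFunction.backward
d_approx = d_int * ((ceil - floor) / 255.0) + floor

max_err = (d_exact - d_approx).abs().max().item()
assert max_err <= (ceil - floor) / 255.0 + 1e-6  # within one quantization step
print(f"max abs quantization error: {max_err:.6f}")

The dither is what lets gradcheck pass with atol of one quantization step: without it the truncation would be biased by half a step on average. The zipformer.py change below then puts the new activation to use.
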
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
index 7332a9eda..fa2bd448d 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
@@ -29,6 +29,7 @@ from scaling import (
     BasicNorm,
     MaxEig,
     DoubleSwish,
+    TanSwish,
     ScaledConv1d,
     ScaledLinear,  # not as in other dirs.. just scales down initial parameter values.
     Whiten,
@@ -1317,7 +1318,7 @@ class AttentionSqueeze(nn.Module):
             max_factor=0.02,
             min_prob=0.1,
         )
-        self.bottleneck_activation = nn.Tanh()  # in bottleneck
+        self.bottleneck_activation = TanSwish()  # in bottleneck
         self.activation = Identity()  # for diagnostics
         # the next two balancers are only to stop parameter-magnitude 'drift': we have
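
On the zipformer.py side, TanSwish is a drop-in replacement for nn.Tanh in the AttentionSqueeze bottleneck: tanh(x) * sigmoid(x - 1) is bounded much like the tanh it replaces (numerically its range is roughly (-0.094, 1) rather than (-1, 1)), approaching 0 for large negative inputs and 1 for large positive ones. A quick smoke test of the new module, assuming it is run from the recipe directory so that scaling.py is importable; the check itself is illustrative, not from the patch:

import torch
from scaling import TanSwish

torch.manual_seed(0)
m = TanSwish()
x = torch.randn(4, 256, requires_grad=True)
y = m(x)            # forward stores the derivative as uint8
y.sum().backward()  # backward dequantizes it

# compare the quantized backward against the exact derivative
with torch.no_grad():
    s = torch.sigmoid(x - 1.0)
    t = torch.tanh(x)
    d_exact = (1 - t ** 2) * s + t * s * (1 - s)
    tol = (0.417895 - (-0.0498088)) / 255.0  # one quantization step
    assert (x.grad - d_exact).abs().max() <= tol + 1e-6
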