From e586cc319c26ced6ad6c27b262328e54589231da Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Sun, 23 Oct 2022 17:11:35 +0800
Subject: [PATCH] Change the discretization of the sigmoid to be expectation
 preserving.

---
 .../ASR/pruned_transducer_stateless7/scaling.py | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
index 3d1cd5f56..e31548737 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
@@ -918,9 +918,9 @@ class DoubleSwishFunction(torch.autograd.Function):
         x = x.detach()
         s = torch.sigmoid(x - 1.0)
         y = x * s
-        # discretize s. Note: .to(torch.uint8) rounds down. We'll correct for this
-        # in an amortized way by adding 0.5 when we convert back to float.
-        s = (s * 255.999).to(torch.uint8)
+        # discretize s. This should be expectation-preserving if we just divide the
+        # result by 255.
+        s = ((s * 255) + torch.rand_like(s)).to(torch.uint8)
         ctx.save_for_backward(s, y)
         return y
 
@@ -928,7 +928,7 @@
     def backward(ctx, y_grad: Tensor) -> Tensor:
         s, y = ctx.saved_tensors
         # converts back to float.
-        s = (s.to(y_grad.dtype) + 0.5) * (1.0 / 255.999)
+        s = s.to(y_grad.dtype) * (1.0 / 255)
         return (y * (1 - s) + s) * y_grad
 
 
@@ -1075,7 +1075,7 @@ def _test_double_swish_deriv():
     x = torch.randn(10, 12, dtype=torch.double) * 0.5
     x.requires_grad = True
     m = DoubleSwish()
-    torch.autograd.gradcheck(m, x, atol=0.01)
+    torch.autograd.gradcheck(m, x, atol=0.02)
 
 
 def _test_softmax():
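
Note on the change (a sketch, not part of the patch): .to(torch.uint8) truncates toward zero, so the discretization is only expectation-preserving when the dither noise is uniform on [0, 1) (torch.rand_like), since E[floor(v + U[0,1))] = v for v in [0, 255]. The snippet below assumes only PyTorch; the helper names old_quant/new_quant are illustrative and do not appear in the patch. It compares the per-value bias of the old round-down-plus-0.5 scheme with the new stochastic-rounding scheme:

    import torch

    def old_quant(s: torch.Tensor) -> torch.Tensor:
        # pre-patch scheme: scale by 255.999, truncate to uint8, then add 0.5
        # back when converting to float.
        q = (s * 255.999).to(torch.uint8)
        return (q.to(torch.float32) + 0.5) * (1.0 / 255.999)

    def new_quant(s: torch.Tensor) -> torch.Tensor:
        # post-patch scheme: dither with uniform [0, 1) noise, truncate to uint8,
        # and dequantize by dividing by 255; unbiased for each input value.
        q = ((s * 255) + torch.rand_like(s)).to(torch.uint8)
        return q.to(torch.float32) * (1.0 / 255)

    # For a single fixed sigmoid value, the deterministic scheme has a fixed
    # error of up to about 1/512, while the dithered scheme is correct on average.
    s = torch.full((1_000_000,), 0.3141)
    print("true value :", 0.3141)
    print("old scheme :", old_quant(s).mean().item())  # fixed offset from 0.3141
    print("new scheme :", new_quant(s).mean().item())  # ~0.3141

Since backward() reconstructs s from the stored uint8 tensor, the stochastic rounding replaces a systematic quantization offset with zero-mean noise, which is why the gradcheck tolerance is loosened to atol=0.02.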