mirror of https://github.com/k2-fsa/icefall.git
Change the discretization of the sigmoid to be expectation preserving.
commit e586cc319c
parent 09cbc9fdab
@@ -918,9 +918,9 @@ class DoubleSwishFunction(torch.autograd.Function):
         x = x.detach()
         s = torch.sigmoid(x - 1.0)
         y = x * s
-        # discretize s. Note: .to(torch.uint8) rounds down. We'll correct for this
-        # in an amortized way by adding 0.5 when we convert back to float.
-        s = (s * 255.999).to(torch.uint8)
+        # discretize s. This should be expectation-preserving if we just divide the
+        # result by 255.
+        s = ((s * 255) + torch.randn_like(s)).to(torch.uint8)
         ctx.save_for_backward(s, y)
         return y

@@ -928,7 +928,7 @@ class DoubleSwishFunction(torch.autograd.Function):
     def backward(ctx, y_grad: Tensor) -> Tensor:
         s, y = ctx.saved_tensors
         # converts back to float.
-        s = (s.to(y_grad.dtype) + 0.5) * (1.0 / 255.999)
+        s = s.to(y_grad.dtype) * (1.0 / 255)
         return (y * (1 - s) + s) * y_grad


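The point of the change: adding random noise to s * 255 before truncating to uint8 makes the rounding unbiased, so the backward pass can simply divide by 255 instead of adding the 0.5 correction that the old round-down scheme needed. Below is a minimal, self-contained sketch of that expectation-preserving quantization, written with uniform noise in [0, 1), the choice for which plain truncation is exactly unbiased; the function names are illustrative and not taken from icefall.

import torch

def quantize_unbiased(s: torch.Tensor) -> torch.Tensor:
    # s is assumed to lie in [0, 1] (a sigmoid output, as in the diff above).
    # Adding u ~ Uniform[0, 1) before truncation gives E[floor(v + u)] = v,
    # so the 8-bit code is an unbiased estimate of 255 * s.
    return ((s * 255) + torch.rand_like(s)).to(torch.uint8)

def dequantize(q: torch.Tensor, dtype: torch.dtype = torch.float32) -> torch.Tensor:
    # No +0.5 correction is needed, because the quantization above is unbiased.
    return q.to(dtype) * (1.0 / 255)

torch.manual_seed(0)
s = torch.sigmoid(torch.randn(10000) - 1.0)
# Averaging many independent quantizations recovers s up to sampling noise.
est = torch.stack([dequantize(quantize_unbiased(s)) for _ in range(200)]).mean(dim=0)
print((est - s).abs().max())  # small (on the order of 1e-4), shrinking with more draws
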
@@ -1075,7 +1075,7 @@ def _test_double_swish_deriv():
     x = torch.randn(10, 12, dtype=torch.double) * 0.5
     x.requires_grad = True
     m = DoubleSwish()
-    torch.autograd.gradcheck(m, x, atol=0.01)
+    torch.autograd.gradcheck(m, x, atol=0.02)


 def _test_softmax():
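The looser gradcheck tolerance accounts for the analytic gradient now using an 8-bit approximation of s, so it can differ from the numerical gradient by up to roughly 1/255 per element. A condensed, self-contained sketch of the same pattern (again using uniform noise; the class name is illustrative, not icefall's) that should pass gradcheck at the new tolerance:

import torch
from torch import Tensor

class QuantizedDoubleSwish(torch.autograd.Function):
    # x * sigmoid(x - 1), caching an 8-bit, expectation-preserving copy of
    # the sigmoid for use in backward().

    @staticmethod
    def forward(ctx, x: Tensor) -> Tensor:
        x = x.detach()
        s = torch.sigmoid(x - 1.0)
        y = x * s
        # Uniform noise in [0, 1) makes the truncation to uint8 unbiased.
        s = ((s * 255) + torch.rand_like(s)).to(torch.uint8)
        ctx.save_for_backward(s, y)
        return y

    @staticmethod
    def backward(ctx, y_grad: Tensor) -> Tensor:
        s, y = ctx.saved_tensors
        s = s.to(y_grad.dtype) * (1.0 / 255)
        # d/dx [x * sigmoid(x - 1)] = s + x * s * (1 - s) = y * (1 - s) + s
        return (y * (1 - s) + s) * y_grad

x = torch.randn(10, 12, dtype=torch.double) * 0.5
x.requires_grad = True
# The 8-bit cache perturbs the analytic gradient by up to about 1/255 (0.004),
# so atol must be comfortably larger than that; 0.02 matches the test above.
torch.autograd.gradcheck(QuantizedDoubleSwish.apply, x, atol=0.02)
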