Mirror of https://github.com/k2-fsa/icefall.git

More memory efficient backprop for DoubleSwish.

commit 36cb279318 (parent 95aaa4a8d2)
@@ -915,22 +915,40 @@ class DoubleSwishFunction(torch.autograd.Function):
     @staticmethod
     def forward(ctx, x: Tensor) -> Tensor:
         requires_grad = x.requires_grad
         x = x.detach()
+        x_dtype = x.dtype
+        if x.dtype == torch.float16:
+            x = x.to(torch.float32)
+
         s = torch.sigmoid(x - 1.0)
         y = x * s
+
         if requires_grad:
-            # discretize s. This should be expectation-preserving if we just divide the
-            # result by 255.
-            s = ((s * 255) + torch.rand_like(s)).clamp(max=255).to(torch.uint8)
-            ctx.save_for_backward(s, y)
+            deriv = (y * (1 - s) + s)
+            # notes on derivative of x * sigmoid(x - 1):
+            # https://www.wolframalpha.com/input?i=d%2Fdx+%28x+*+sigmoid%28x-1%29%29
+            # min \simeq -0.043638.  Take floor as -0.043637 so it's a lower bound.
+            # max \simeq 1.1990.  Take ceil to be 1.2 so it's an upper bound.
+            # the combination of "+ torch.rand_like(deriv)" and casting to torch.uint8
+            # (which floors) should be expectation-preserving.
+            floor = -0.043637
+            ceil = 1.2
+            d_scaled = ((deriv - floor) * (255.0 / (ceil - floor)) + torch.rand_like(deriv))
+            if __name__ == "__main__":
+                # for self-testing only.
+                assert d_scaled.min() >= 0.0
+                assert d_scaled.max() < 256.0
+            d_int = d_scaled.to(torch.uint8)
+            ctx.save_for_backward(d_int)
         return y

     @staticmethod
     def backward(ctx, y_grad: Tensor) -> Tensor:
-        s, y = ctx.saved_tensors
-        # converts back to float.
-        s = s.to(y_grad.dtype) * (1.0 / 255)
-        return (y * (1 - s) + s) * y_grad
+        d, = ctx.saved_tensors
+        # the same constants as used in forward pass.
+        floor = -0.043637
+        ceil = 1.2
+        d = (d * ((ceil - floor) / 255.0) + floor)
+        return (y_grad * d)


 class DoubleSwish(torch.nn.Module):
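The hunk above is what makes the backprop cheaper on memory: instead of saving the uint8-discretized sigmoid s plus the full-precision output y, forward() now computes the derivative directly (since y = x * s, we have d/dx [x * sigmoid(x - 1)] = s + x * s * (1 - s) = y * (1 - s) + s) and saves only that derivative, quantized to one uint8 per element. The quantization is expectation-preserving stochastic rounding: map the derivative from [floor, ceil] onto [0, 255], add uniform noise in [0, 1), and let the uint8 cast floor the result. The following minimal sketch is not part of the patch (all names are local to it); it only demonstrates the quantize/dequantize round trip that forward() and backward() perform together:

# Standalone sketch of the quantize/dequantize round trip used above
# (illustrative only, not code from the commit).
import torch

floor, ceil = -0.043637, 1.2                # bounds on d/dx (x * sigmoid(x - 1))

x = torch.randn(1000, 1000)
s = torch.sigmoid(x - 1.0)
y = x * s
deriv = y * (1 - s) + s                     # exact derivative, as computed in forward()

# Quantize: map [floor, ceil] onto [0, 255], add U[0, 1) noise, let the uint8 cast floor.
d_scaled = (deriv - floor) * (255.0 / (ceil - floor)) + torch.rand_like(deriv)
d_int = d_scaled.to(torch.uint8)            # one byte per element is all backward() keeps

# Dequantize exactly as backward() does, then compare with the true derivative.
d_hat = d_int.to(torch.float32) * ((ceil - floor) / 255.0) + floor
err = d_hat - deriv
print("max |error|:", err.abs().max().item())   # < (ceil - floor) / 255 ~= 0.0049
print("mean error: ", err.mean().item())        # ~ 0, i.e. unbiased (expectation-preserving)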
@@ -1073,10 +1091,19 @@ def _test_basic_norm():


 def _test_double_swish_deriv():
-    x = torch.randn(10, 12, dtype=torch.double) * 0.5
+    x = torch.randn(10, 12, dtype=torch.double) * 3.0
     x.requires_grad = True
     m = DoubleSwish()
-    torch.autograd.gradcheck(m, x, atol=0.02)
+
+    tol = ((1.2 - (-0.043637)) / 255.0)
+    torch.autograd.gradcheck(m, x, atol=tol)
+
+    # for self-test.
+    x = torch.randn(1000, 1000, dtype=torch.double) * 3.0
+    x.requires_grad = True
+    y = m(x)


 def _test_softmax():
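The test change tightens the gradcheck tolerance from 0.02 to exactly one quantization step, (1.2 - (-0.043637)) / 255 ~= 0.0049, which is also the worst-case error the uint8-quantized derivative can introduce, and adds a larger random input (1000 x 1000, scaled by 3.0) whose forward pass exercises the `if __name__ == "__main__"` range assertions in forward(). A back-of-envelope sketch of the numbers involved follows; it is illustrative only and assumes float32 activations for the old code path:

# Back-of-envelope numbers (not part of the patch).
floor, ceil = -0.043637, 1.2

# gradcheck tolerance used by _test_double_swish_deriv(): one quantization step,
# which also bounds the error of the dequantized derivative.
tol = (ceil - floor) / 255.0
print(f"quantization step (gradcheck atol): {tol:.6f}")      # ~0.004877

# Memory kept for the backward pass, per element of x: the old code stored a
# uint8 s plus a float32 y; the new code stores a single uint8 derivative.
old_bytes, new_bytes = 1 + 4, 1
print(f"bytes saved for backward, per element: {old_bytes} -> {new_bytes}")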