Fix deriv code

Daniel Povey 2022-12-04 21:22:06 +08:00
parent c57eaf7979
commit 12fb2081b1


@@ -1217,8 +1217,6 @@ class TanSwish(torch.nn.Module):
 class SwooshLFunction(torch.autograd.Function):
     """
     swoosh(x) = log(1 + exp(x-4)) - 0.08*x - 0.035
     derivatives are between -0.08 and 0.92.
     """
     @staticmethod
@@ -1231,19 +1229,21 @@ class SwooshLFunction(torch.autograd.Function):
         zero = torch.tensor(0.0, dtype=x.dtype, device=x.device)
+        coeff = -0.08
         with torch.cuda.amp.autocast(enabled=False):
             with torch.enable_grad():
                 x = x.detach()
                 x.requires_grad = True
-                y = torch.logaddexp(zero, x - 4.0) - 0.08 * x - 0.035
+                y = torch.logaddexp(zero, x - 4.0) + coeff * x - 0.035
                 if not requires_grad:
                     return y
                 y.backward(gradient = torch.ones_like(y))
                 grad = x.grad
-                floor = -0.1
-                ceil = 0.905 # real ceil would be 0.09, give it extra room for roundoff.
+                floor = coeff
+                ceil = 1.0 + coeff + 0.005
                 d_scaled = ((grad - floor) * (255.0 / (ceil - floor)) + torch.rand_like(grad))
                 if __name__ == "__main__":
@@ -1261,8 +1261,10 @@ class SwooshLFunction(torch.autograd.Function):
     def backward(ctx, y_grad: Tensor) -> Tensor:
         d, = ctx.saved_tensors
         # the same constants as used in forward pass.
-        floor = -0.1
-        ceil = 0.905
+        coeff = -0.08
+        floor = coeff
+        ceil = 1.0 + coeff + 0.005
         d = (d * ((ceil - floor) / 255.0) + floor)
         return (y_grad * d)
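
Why the new constants are right: the derivative of swoosh_l(x) = log(1 + exp(x - 4)) + coeff * x - 0.035 is sigmoid(x - 4) + coeff, which lies strictly inside (coeff, 1 + coeff) = (-0.08, 0.92). The new floor and ceil track those bounds (plus 0.005 of headroom on the ceiling for roundoff), whereas the old ceiling of 0.905 sits below the true maximum of 0.92. Below is a minimal standalone sketch of that check and of the 0..255 round-trip used by forward/backward; the test range and variable names are illustrative, not taken from the repository.

import torch

# Constants from the patch: the derivative of swoosh_l is sigmoid(x - 4) + coeff,
# which lies in (coeff, 1 + coeff) = (-0.08, 0.92); ceil gets 0.005 of headroom.
coeff = -0.08
floor = coeff
ceil = 1.0 + coeff + 0.005

# Evaluate swoosh_l and its derivative over a wide range of inputs
# (illustrative test range, not from the repository).
x = torch.linspace(-50.0, 50.0, 10001, requires_grad=True)
y = torch.logaddexp(torch.zeros_like(x), x - 4.0) + coeff * x - 0.035
y.backward(gradient=torch.ones_like(y))
grad = x.grad

# The new floor/ceil bracket the derivative everywhere.
assert grad.min() >= floor and grad.max() <= ceil

# Round-trip through the 0..255 encoding: the forward pass scales the derivative
# into [0, 255] with random dithering, the backward pass maps the stored byte back.
d_scaled = (grad - floor) * (255.0 / (ceil - floor)) + torch.rand_like(grad)
d_int = d_scaled.clamp(min=0.0, max=255.0).to(torch.uint8)
d_restored = d_int.to(torch.float32) * ((ceil - floor) / 255.0) + floor

# Reconstruction error is at most one quantization step.
assert (d_restored - grad).abs().max() <= (ceil - floor) / 255.0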