Mirror of https://github.com/k2-fsa/icefall.git

Fix deriv code
Commit 12fb2081b1 (parent c57eaf7979)
@@ -1217,8 +1217,6 @@ class TanSwish(torch.nn.Module):
 class SwooshLFunction(torch.autograd.Function):
     """
     swoosh(x) = log(1 + exp(x-4)) - 0.08*x - 0.035
-
-    derivatives are between -0.08 and 0.92.
     """
 
     @staticmethod
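The docstring line removed above claimed that the derivatives lie between -0.08 and 0.92, which is in fact correct: d/dx [log(1 + exp(x-4)) - 0.08*x - 0.035] = sigmoid(x-4) - 0.08, which tends to -0.08 as x goes to -inf and to 0.92 as x goes to +inf. A quick standalone check (an illustrative sketch, not part of this commit):

import torch

# swoosh_l'(x) = sigmoid(x - 4) - 0.08, so gradients must fall in (-0.08, 0.92).
x = torch.linspace(-20.0, 20.0, steps=10001, requires_grad=True)
y = torch.logaddexp(torch.zeros_like(x), x - 4.0) - 0.08 * x - 0.035
y.backward(gradient=torch.ones_like(y))
print(x.grad.min().item(), x.grad.max().item())  # close to -0.08 and 0.92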
@@ -1231,19 +1229,21 @@ class SwooshLFunction(torch.autograd.Function):
 
         zero = torch.tensor(0.0, dtype=x.dtype, device=x.device)
 
+        coeff = -0.08
+
         with torch.cuda.amp.autocast(enabled=False):
             with torch.enable_grad():
                 x = x.detach()
                 x.requires_grad = True
-                y = torch.logaddexp(zero, x - 4.0) - 0.08 * x - 0.035
+                y = torch.logaddexp(zero, x - 4.0) + coeff * x - 0.035
 
                 if not requires_grad:
                     return y
                 y.backward(gradient = torch.ones_like(y))
 
                 grad = x.grad
-                floor = -0.1
-                ceil = 0.905  # real ceil would be 0.09, give it extra room for roundoff.
+                floor = coeff
+                ceil = 1.0 + coeff + 0.005
 
                 d_scaled = ((grad - floor) * (255.0 / (ceil - floor)) + torch.rand_like(grad))
                 if __name__ == "__main__":
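The floor/ceil pair defines the affine map that quantizes the derivative into a uint8 for the backward pass, so it must just cover the true derivative range. With the old hard-coded constants (floor = -0.1, ceil = 0.905), a derivative near the true maximum of 0.92 would map to (0.92 + 0.1) * 255 / 1.005, about 258.8, past the top of the uint8 range; the removed comment's "0.09" also looks like a typo for 0.92. The new values are derived from coeff instead: floor = -0.08 and ceil = 1.0 - 0.08 + 0.005 = 0.925, i.e. the exact bounds plus 0.005 of headroom for roundoff. The torch.rand_like term adds uniform noise before truncation, giving stochastic rounding. A round-trip sketch under the new constants (illustration only, not the file's code):

import torch

coeff = -0.08
floor = coeff               # exact lower bound of swoosh_l'(x)
ceil = 1.0 + coeff + 0.005  # exact upper bound 0.92, plus roundoff headroom

# Pretend these are derivative values from the forward pass.
grad = torch.empty(10000).uniform_(-0.08, 0.92)

# Quantize: affine map from [floor, ceil] to [0, 255], plus uniform noise
# for stochastic rounding; values stay below 255, so the byte cast is safe.
d_scaled = (grad - floor) * (255.0 / (ceil - floor)) + torch.rand_like(grad)
d_int = d_scaled.to(torch.uint8)

# Dequantize, mirroring backward():
d = d_int * ((ceil - floor) / 255.0) + floor
print((d - grad).abs().max().item())  # under one step, 1.005 / 255 ~= 0.0039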
@@ -1261,8 +1261,10 @@ class SwooshLFunction(torch.autograd.Function):
     def backward(ctx, y_grad: Tensor) -> Tensor:
         d, = ctx.saved_tensors
         # the same constants as used in forward pass.
-        floor = -0.1
-        ceil = 0.905
+        coeff = -0.08
+        floor = coeff
+        ceil = 1.0 + coeff + 0.005
         d = (d * ((ceil - floor) / 255.0) + floor)
         return (y_grad * d)
 
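End to end, the quantized backward should now agree with ordinary autograd on the analytic formula to within one quantization step. A hedged sanity check, assuming the post-commit SwooshLFunction from this file is in scope:

import torch

x = torch.randn(1000, requires_grad=True)
y = SwooshLFunction.apply(x)
y.sum().backward()

# Reference gradient via autograd on the analytic formula.
x_ref = x.detach().clone().requires_grad_(True)
y_ref = torch.logaddexp(torch.zeros_like(x_ref), x_ref - 4.0) - 0.08 * x_ref - 0.035
y_ref.sum().backward()

# Max discrepancy should stay below one uint8 step, 1.005 / 255 ~= 0.0039.
print((x.grad - x_ref.grad).abs().max().item())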