Mirror of https://github.com/k2-fsa/icefall.git
Make DoubleSwish more memory efficient
commit ae25688253
parent f351777e9c
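This change makes DoubleSwish cheaper in activation memory by wrapping it in a custom torch.autograd.Function: the forward saves only the detached input, and the backward recomputes x * sigmoid(x - 1) under torch.enable_grad(), instead of letting autograd keep the intermediate tensors of the eager expression alive. The sketch below is an illustration only, not code from this commit; it assumes a PyTorch recent enough to provide torch.autograd.graph.saved_tensors_hooks, and shows roughly what autograd stashes for the implementation being removed:

import torch


def old_double_swish(x):
    # the eager implementation this commit removes: (x-1)*s + s == x*s
    x1 = x - 1.0
    s = torch.sigmoid(x1)
    return (x1 * s) + s


saved = []  # bytes of every tensor autograd stashes for backward


def pack(t):
    saved.append(t.numel() * t.element_size())
    return t


with torch.autograd.graph.saved_tensors_hooks(pack, lambda t: t):
    x = torch.randn(1000, requires_grad=True)
    y = old_double_swish(x)

# sigmoid saves its output and the elementwise multiply saves both of its
# operands, so roughly three input-sized tensors stay alive until backward.
print(len(saved), sum(saved))

The DoubleSwishFunction added below instead stores a single input-sized tensor via ctx.save_for_backward(x.detach()) and rebuilds the activation graph only while its backward() runs, at the cost of recomputing the sigmoid once.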
@@ -267,23 +267,6 @@ def _exp_scale_relu(x: Tensor, scale: Tensor, speed: float) -> Tensor:
     return (x * (scale * speed).exp()).relu()
 
 
-class ExpScaleReluFunction(torch.autograd.Function):
-    @staticmethod
-    def forward(ctx, x: Tensor, scale: Tensor, speed: float) -> Tensor:
-        ctx.save_for_backward(x.detach(), scale.detach())
-        ctx.speed = speed
-        return _exp_scale_swish(x, scale, speed)
-
-    @staticmethod
-    def backward(ctx, y_grad: Tensor) -> Tensor:
-        x, scale = ctx.saved_tensors
-        x.requires_grad = True
-        scale.requires_grad = True
-        with torch.enable_grad():
-            y = _exp_scale_swish(x, scale, ctx.speed)
-            y.backward(gradient=y_grad)
-            return x.grad, scale.grad, None
-
 
 
 class ExpScaleReluFunction(torch.autograd.Function):
@@ -563,16 +546,32 @@ class DerivBalancer(torch.nn.Module):
                                            self.max_abs)
 
 
+def _double_swish(x: Tensor) -> Tensor:
+    # double-swish, implemented/approximated as offset-swish
+    return x * torch.sigmoid(x - 1.0)
+
+class DoubleSwishFunction(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, x: Tensor) -> Tensor:
+        ctx.save_for_backward(x.detach())
+        return _double_swish(x)
+
+    @staticmethod
+    def backward(ctx, y_grad: Tensor) -> Tensor:
+        # TODO: can make this more efficient.
+        x, = ctx.saved_tensors
+        x.requires_grad = True
+        with torch.enable_grad():
+            y = _double_swish(x)
+            y.backward(gradient=y_grad)
+            return x.grad
+
 class DoubleSwish(torch.nn.Module):
     def forward(self, x: Tensor) -> Tensor:
         """Return double-swish activation function which is an approximation to Swish(Swish(x)),
-        that we approximate closely with x * sigmoid(x-1), expressed for more memory-efficient
-        backprop as (x-1) * torch.sigmoid(x - 1) + torch.sigmoid(x - 1)
+        that we approximate closely with x * sigmoid(x-1).
         """
-        x1 = x - 1.0
-        s = torch.sigmoid(x1)
-        return (x1 * s) + s  # (x-1) * s + s == x * s
-
 
 
 def _test_exp_scale_swish():
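As a quick sanity check (again a sketch for this page, not part of the commit), the recompute-in-backward Function can be compared against the eager expression it replaces; both definitions are copied from the diff so the snippet runs on its own:

import torch
from torch import Tensor


def _double_swish(x: Tensor) -> Tensor:
    # double-swish, implemented/approximated as offset-swish
    return x * torch.sigmoid(x - 1.0)


class DoubleSwishFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x: Tensor) -> Tensor:
        ctx.save_for_backward(x.detach())
        return _double_swish(x)

    @staticmethod
    def backward(ctx, y_grad: Tensor) -> Tensor:
        x, = ctx.saved_tensors
        x.requires_grad = True
        with torch.enable_grad():
            y = _double_swish(x)
            y.backward(gradient=y_grad)
            return x.grad


def old_double_swish(x: Tensor) -> Tensor:
    # the eager version removed by this commit: (x-1)*s + s == x*s
    x1 = x - 1.0
    s = torch.sigmoid(x1)
    return (x1 * s) + s


x = torch.randn(20, dtype=torch.double, requires_grad=True)
g = torch.randn_like(x)

# forward values agree
assert torch.allclose(DoubleSwishFunction.apply(x), old_double_swish(x))

# gradients agree for the same upstream gradient
(g_new,) = torch.autograd.grad(DoubleSwishFunction.apply(x), x, g)
(g_old,) = torch.autograd.grad(old_double_swish(x), x, g)
assert torch.allclose(g_new, g_old)

The trade-off is an extra elementwise pass during backward to rebuild the activation; the TODO left in the new backward() notes that this could later be made more efficient, for example with a closed-form derivative.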