mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 01:52:41 +00:00
Make DoubleSwish more memory efficient
This commit is contained in:
parent
f351777e9c
commit
ae25688253
@ -267,23 +267,6 @@ def _exp_scale_relu(x: Tensor, scale: Tensor, speed: float) -> Tensor:
|
||||
return (x * (scale * speed).exp()).relu()
|
||||
|
||||
|
||||
class ExpScaleReluFunction(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, x: Tensor, scale: Tensor, speed: float) -> Tensor:
|
||||
ctx.save_for_backward(x.detach(), scale.detach())
|
||||
ctx.speed = speed
|
||||
return _exp_scale_swish(x, scale, speed)
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, y_grad: Tensor) -> Tensor:
|
||||
x, scale = ctx.saved_tensors
|
||||
x.requires_grad = True
|
||||
scale.requires_grad = True
|
||||
with torch.enable_grad():
|
||||
y = _exp_scale_swish(x, scale, ctx.speed)
|
||||
y.backward(gradient=y_grad)
|
||||
return x.grad, scale.grad, None
|
||||
|
||||
|
||||
|
||||
class ExpScaleReluFunction(torch.autograd.Function):
|
||||
@ -563,16 +546,32 @@ class DerivBalancer(torch.nn.Module):
|
||||
self.max_abs)
|
||||
|
||||
|
||||
def _double_swish(x: Tensor) -> Tensor:
|
||||
# double-swish, implemented/approximated as offset-swish
|
||||
return x * torch.sigmoid(x - 1.0)
|
||||
|
||||
class DoubleSwishFunction(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, x: Tensor) -> Tensor:
|
||||
ctx.save_for_backward(x.detach())
|
||||
return _double_swish(x)
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, y_grad: Tensor) -> Tensor:
|
||||
# TODO: can make this more efficient.
|
||||
x, = ctx.saved_tensors
|
||||
x.requires_grad = True
|
||||
with torch.enable_grad():
|
||||
y = _double_swish(x)
|
||||
y.backward(gradient=y_grad)
|
||||
return x.grad
|
||||
|
||||
class DoubleSwish(torch.nn.Module):
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
"""Return double-swish activation function which is an approximation to Swish(Swish(x)),
|
||||
that we approximate closely with x * sigmoid(x-1), expressed for more memory-efficient
|
||||
backprop as (x-1) * torch.sigmoid(x - 1) + torch.sigmoid(x - 1)
|
||||
that we approximate closely with x * sigmoid(x-1).
|
||||
"""
|
||||
x1 = x - 1.0
|
||||
s = torch.sigmoid(x1)
|
||||
return (x1 * s) + s # (x-1) * s + s == x * s
|
||||
|
||||
return DoubleSwishFunction.apply(x)
|
||||
|
||||
def _test_exp_scale_swish():
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user