Make DoubleSwish more memory efficient

Daniel Povey 2022-03-14 11:02:32 +08:00
parent f351777e9c
commit ae25688253


@@ -267,23 +267,6 @@ def _exp_scale_relu(x: Tensor, scale: Tensor, speed: float) -> Tensor:
     return (x * (scale * speed).exp()).relu()
 
 
-class ExpScaleReluFunction(torch.autograd.Function):
-    @staticmethod
-    def forward(ctx, x: Tensor, scale: Tensor, speed: float) -> Tensor:
-        ctx.save_for_backward(x.detach(), scale.detach())
-        ctx.speed = speed
-        return _exp_scale_swish(x, scale, speed)
-
-    @staticmethod
-    def backward(ctx, y_grad: Tensor) -> Tensor:
-        x, scale = ctx.saved_tensors
-        x.requires_grad = True
-        scale.requires_grad = True
-        with torch.enable_grad():
-            y = _exp_scale_swish(x, scale, ctx.speed)
-            y.backward(gradient=y_grad)
-        return x.grad, scale.grad, None
-
-
 class ExpScaleReluFunction(torch.autograd.Function):
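Both the removed duplicate above and the new DoubleSwishFunction in the next hunk rely on the same recompute-in-backward pattern: forward() saves only the detached input, and backward() re-runs the cheap elementwise computation under torch.enable_grad() instead of keeping intermediate tensors (such as the sigmoid output) alive from the forward pass. A minimal self-contained sketch of that pattern, using plain Swish as the body (the class name RecomputingSwish is illustrative, not from this commit):

import torch
from torch import Tensor

class RecomputingSwish(torch.autograd.Function):
    # Illustrative sketch of the recompute-in-backward pattern;
    # not part of the repo.
    @staticmethod
    def forward(ctx, x: Tensor) -> Tensor:
        # Save only the detached input; the sigmoid output is not
        # kept for backward.
        ctx.save_for_backward(x.detach())
        return x * torch.sigmoid(x)

    @staticmethod
    def backward(ctx, y_grad: Tensor) -> Tensor:
        x, = ctx.saved_tensors
        x.requires_grad = True
        # Recompute the forward under enable_grad() and let autograd
        # differentiate the recomputed graph.
        with torch.enable_grad():
            y = x * torch.sigmoid(x)
            y.backward(gradient=y_grad)
        return x.grad

The trade is one extra elementwise forward pass during backward in exchange for not storing intermediates, which is usually worthwhile for activations whose recomputation is cheap relative to the memory they would otherwise hold.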
@@ -563,16 +546,32 @@ class DerivBalancer(torch.nn.Module):
                              self.max_abs)
 
 
+def _double_swish(x: Tensor) -> Tensor:
+    # double-swish, implemented/approximated as offset-swish
+    return x * torch.sigmoid(x - 1.0)
+
+
+class DoubleSwishFunction(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, x: Tensor) -> Tensor:
+        ctx.save_for_backward(x.detach())
+        return _double_swish(x)
+
+    @staticmethod
+    def backward(ctx, y_grad: Tensor) -> Tensor:
+        # TODO: can make this more efficient.
+        x, = ctx.saved_tensors
+        x.requires_grad = True
+        with torch.enable_grad():
+            y = _double_swish(x)
+            y.backward(gradient=y_grad)
+        return x.grad
+
+
 class DoubleSwish(torch.nn.Module):
     def forward(self, x: Tensor) -> Tensor:
         """Return double-swish activation function which is an approximation to Swish(Swish(x)),
-        that we approximate closely with x * sigmoid(x-1), expressed for more memory-efficient
-        backprop as (x-1) * torch.sigmoid(x - 1) + torch.sigmoid(x - 1)
+        that we approximate closely with x * sigmoid(x-1).
+        """
-        """
-        x1 = x - 1.0
-        s = torch.sigmoid(x1)
-        return (x1 * s) + s  # (x-1) * s + s == x * s
+        return DoubleSwishFunction.apply(x)
 
 
 def _test_exp_scale_swish():
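The TODO in DoubleSwishFunction.backward points at a further saving: with s = sigmoid(x - 1), the derivative of x * s is s + x * s * (1 - s), so the gradient can be formed in closed form from a single sigmoid instead of recomputing the forward graph. A hedged sketch of that variant, checked against numerical gradients (the class name AnalyticDoubleSwish is ours, not from this commit):

import torch
from torch import Tensor

class AnalyticDoubleSwish(torch.autograd.Function):
    # Hypothetical variant of DoubleSwishFunction with a closed-form
    # gradient instead of recomputation; not part of the commit.
    @staticmethod
    def forward(ctx, x: Tensor) -> Tensor:
        ctx.save_for_backward(x.detach())
        return x * torch.sigmoid(x - 1.0)

    @staticmethod
    def backward(ctx, y_grad: Tensor) -> Tensor:
        x, = ctx.saved_tensors
        s = torch.sigmoid(x - 1.0)
        # d/dx [x * sigmoid(x - 1)] = s + x * s * (1 - s)
        return y_grad * (s + x * s * (1.0 - s))

# gradcheck compares the analytic backward with finite differences;
# it expects double-precision inputs with requires_grad=True.
x = torch.randn(20, dtype=torch.float64, requires_grad=True)
torch.autograd.gradcheck(AnalyticDoubleSwish.apply, (x,))

Note also that the old inline formula, (x - 1) * s + s, is algebraically identical to the new x * s, so the forward values are unchanged by this commit; only the backward's memory behavior differs.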