mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 01:52:41 +00:00
Combine ExpScale and swish for memory reduction
This commit is contained in:
parent
23b3aa233c
commit
bc6c720e25
@ -206,3 +206,70 @@ class ExpScale(torch.nn.Module):
|
|||||||
|
|
||||||
def forward(self, x: Tensor) -> Tensor:
|
def forward(self, x: Tensor) -> Tensor:
|
||||||
return x * (self.scale * self.speed).exp()
|
return x * (self.scale * self.speed).exp()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def _exp_scale_swish(x: Tensor, scale: Tensor, speed: float) -> Tensor:
|
||||||
|
return (x * torch.sigmoid(x)) * (scale * speed).exp()
|
||||||
|
|
||||||
|
|
||||||
|
def _exp_scale_swish_backward(x: Tensor, scale: Tensor, speed: float) -> Tensor:
|
||||||
|
return (x * torch.sigmoid(x)) * (scale * speed).exp()
|
||||||
|
|
||||||
|
|
||||||
|
class ExpScaleSwishFunction(torch.autograd.Function):
|
||||||
|
@staticmethod
|
||||||
|
def forward(ctx, x: Tensor, scale: Tensor, speed: float) -> Tensor:
|
||||||
|
ctx.save_for_backward(x, scale)
|
||||||
|
ctx.speed = speed
|
||||||
|
return _exp_scale_swish(x, scale, speed)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def backward(ctx, y_grad: Tensor) -> Tensor:
|
||||||
|
x, scale = ctx.saved_tensors
|
||||||
|
x.requires_grad = True
|
||||||
|
scale.requires_grad = True
|
||||||
|
with torch.enable_grad():
|
||||||
|
y = _exp_scale_swish(x, scale, ctx.speed)
|
||||||
|
y.backward(gradient=y_grad)
|
||||||
|
return x.grad, scale.grad, None
|
||||||
|
|
||||||
|
|
||||||
|
class ExpScaleSwish(torch.nn.Module):
|
||||||
|
# combines ExpScale an Swish
|
||||||
|
# caution: need to specify name for speed, e.g. ExpScaleSwish(50, speed=4.0)
|
||||||
|
def __init__(self, *shape, speed: float = 1.0):
|
||||||
|
super(ExpScaleSwish, self).__init__()
|
||||||
|
self.scale = nn.Parameter(torch.zeros(*shape))
|
||||||
|
self.speed = speed
|
||||||
|
|
||||||
|
def forward(self, x: Tensor) -> Tensor:
|
||||||
|
return ExpScaleSwishFunction.apply(x, self.scale, self.speed)
|
||||||
|
# return (x * torch.sigmoid(x)) * (self.scale * self.speed).exp()
|
||||||
|
# return x * (self.scale * self.speed).exp()
|
||||||
|
|
||||||
|
def _test_exp_scale_swish():
|
||||||
|
class Swish(torch.nn.Module):
|
||||||
|
def forward(self, x: Tensor) -> Tensor:
|
||||||
|
"""Return Swich activation function."""
|
||||||
|
return x * torch.sigmoid(x)
|
||||||
|
|
||||||
|
x1 = torch.randn(50, 60).detach()
|
||||||
|
x2 = x1.detach()
|
||||||
|
|
||||||
|
m1 = ExpScaleSwish(50, 1, speed=4.0)
|
||||||
|
m2 = torch.nn.Sequential(Swish(), ExpScale(50, 1, speed=4.0))
|
||||||
|
x1.requires_grad = True
|
||||||
|
x2.requires_grad = True
|
||||||
|
|
||||||
|
y1 = m1(x1)
|
||||||
|
y2 = m2(x2)
|
||||||
|
assert torch.allclose(y1, y2)
|
||||||
|
y1.sum().backward()
|
||||||
|
y2.sum().backward()
|
||||||
|
assert torch.allclose(x1.grad, x2.grad)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
_test_exp_scale_swish()
|
||||||
|
@ -156,8 +156,7 @@ class ConformerEncoderLayer(nn.Module):
|
|||||||
|
|
||||||
self.feed_forward = nn.Sequential(
|
self.feed_forward = nn.Sequential(
|
||||||
nn.Linear(d_model, dim_feedforward),
|
nn.Linear(d_model, dim_feedforward),
|
||||||
Swish(),
|
ExpScaleSwish(dim_feedforward, speed=4.0),
|
||||||
ExpScale(dim_feedforward, speed=4.0),
|
|
||||||
nn.Dropout(dropout),
|
nn.Dropout(dropout),
|
||||||
nn.Linear(dim_feedforward, d_model),
|
nn.Linear(dim_feedforward, d_model),
|
||||||
)
|
)
|
||||||
@ -165,7 +164,7 @@ class ConformerEncoderLayer(nn.Module):
|
|||||||
self.feed_forward_macaron = nn.Sequential(
|
self.feed_forward_macaron = nn.Sequential(
|
||||||
nn.Linear(d_model, dim_feedforward),
|
nn.Linear(d_model, dim_feedforward),
|
||||||
Swish(),
|
Swish(),
|
||||||
ExpScale(dim_feedforward, speed=4.0),
|
ExpScaleSwish(dim_feedforward, speed=4.0),
|
||||||
nn.Dropout(dropout),
|
nn.Dropout(dropout),
|
||||||
nn.Linear(dim_feedforward, d_model),
|
nn.Linear(dim_feedforward, d_model),
|
||||||
)
|
)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user