mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
Use Swoosh-R in the Conv and Swoosh-L in the feedforward.
This commit is contained in:
parent
d214e1c352
commit
7b1f093077
@ -1214,7 +1214,7 @@ class TanSwish(torch.nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
class SwooshFunction(torch.autograd.Function):
|
class SwooshLFunction(torch.autograd.Function):
|
||||||
"""
|
"""
|
||||||
swoosh(x) = log(1 + exp(x-1)) - 0.08*x - 0.313261687
|
swoosh(x) = log(1 + exp(x-1)) - 0.08*x - 0.313261687
|
||||||
|
|
||||||
@ -1267,14 +1267,77 @@ class SwooshFunction(torch.autograd.Function):
|
|||||||
return (y_grad * d)
|
return (y_grad * d)
|
||||||
|
|
||||||
|
|
||||||
class Swoosh(torch.nn.Module):
|
class SwooshL(torch.nn.Module):
|
||||||
def forward(self, x: Tensor) -> Tensor:
|
def forward(self, x: Tensor) -> Tensor:
|
||||||
"""Return tan-swish activation function which is tanh(x) sigmoid(x-1)n
|
"""Return Swoosh-L activation.
|
||||||
"""
|
"""
|
||||||
if torch.jit.is_scripting():
|
if torch.jit.is_scripting():
|
||||||
zero = torch.tensor(0.0, dtype=x.dtype, device=x.device)
|
zero = torch.tensor(0.0, dtype=x.dtype, device=x.device)
|
||||||
return torch.logaddexp(zero, x - 4.0) - 0.1 * x - 0.035
|
return torch.logaddexp(zero, x - 4.0) - 0.1 * x - 0.035
|
||||||
return SwooshFunction.apply(x)
|
return SwooshLFunction.apply(x)
|
||||||
|
|
||||||
|
|
||||||
|
class SwooshRFunction(torch.autograd.Function):
|
||||||
|
"""
|
||||||
|
swoosh(x) = log(1 + exp(x-1)) - 0.08*x - 0.313261687
|
||||||
|
|
||||||
|
derivatives are between -0.08 and 0.92.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def forward(ctx, x: Tensor) -> Tensor:
|
||||||
|
requires_grad = x.requires_grad
|
||||||
|
x_dtype = x.dtype
|
||||||
|
|
||||||
|
if x.dtype == torch.float16:
|
||||||
|
x = x.to(torch.float32)
|
||||||
|
|
||||||
|
zero = torch.tensor(0.0, dtype=x.dtype, device=x.device)
|
||||||
|
|
||||||
|
with torch.cuda.amp.autocast(enabled=False):
|
||||||
|
with torch.enable_grad():
|
||||||
|
x = x.detach()
|
||||||
|
x.requires_grad = True
|
||||||
|
y = torch.logaddexp(zero, x - 1.) - 0.08 * x - 0.313261687
|
||||||
|
|
||||||
|
if not requires_grad:
|
||||||
|
return y
|
||||||
|
y.backward(gradient = torch.ones_like(y))
|
||||||
|
|
||||||
|
grad = x.grad
|
||||||
|
floor = -0.08
|
||||||
|
ceil = 0.925
|
||||||
|
|
||||||
|
d_scaled = ((grad - floor) * (255.0 / (ceil - floor)) + torch.rand_like(grad))
|
||||||
|
if __name__ == "__main__":
|
||||||
|
# for self-testing only.
|
||||||
|
assert d_scaled.min() >= 0.0
|
||||||
|
assert d_scaled.max() < 256.0
|
||||||
|
|
||||||
|
d_int = d_scaled.to(torch.uint8)
|
||||||
|
ctx.save_for_backward(d_int)
|
||||||
|
if x.dtype == torch.float16 or torch.is_autocast_enabled():
|
||||||
|
y = y.to(torch.float16)
|
||||||
|
return y
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def backward(ctx, y_grad: Tensor) -> Tensor:
|
||||||
|
d, = ctx.saved_tensors
|
||||||
|
# the same constants as used in forward pass.
|
||||||
|
floor = -0.08
|
||||||
|
ceil = 0.925
|
||||||
|
d = (d * ((ceil - floor) / 255.0) + floor)
|
||||||
|
return (y_grad * d)
|
||||||
|
|
||||||
|
|
||||||
|
class SwooshR(torch.nn.Module):
|
||||||
|
def forward(self, x: Tensor) -> Tensor:
|
||||||
|
"""Return Swoosh-L activation.
|
||||||
|
"""
|
||||||
|
if torch.jit.is_scripting():
|
||||||
|
zero = torch.tensor(0.0, dtype=x.dtype, device=x.device)
|
||||||
|
return torch.logaddexp(zero, x - 1.) - 0.08 * x - 0.313261687
|
||||||
|
return SwooshRFunction.apply(x)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@ -1434,10 +1497,23 @@ def _test_tan_swish_deriv():
|
|||||||
x.requires_grad = True
|
x.requires_grad = True
|
||||||
y = m(x)
|
y = m(x)
|
||||||
|
|
||||||
def _test_swoosh_deriv():
|
def _test_swooshl_deriv():
|
||||||
x = torch.randn(10, 12, dtype=torch.double) * 3.0
|
x = torch.randn(10, 12, dtype=torch.double) * 3.0
|
||||||
x.requires_grad = True
|
x.requires_grad = True
|
||||||
m = Swoosh()
|
m = SwooshL()
|
||||||
|
|
||||||
|
tol = (1.0 / 255.0)
|
||||||
|
torch.autograd.gradcheck(m, x, atol=tol)
|
||||||
|
|
||||||
|
# for self-test.
|
||||||
|
x = torch.randn(1000, 1000, dtype=torch.double) * 3.0
|
||||||
|
x.requires_grad = True
|
||||||
|
y = m(x)
|
||||||
|
|
||||||
|
def _test_swooshr_deriv():
|
||||||
|
x = torch.randn(10, 12, dtype=torch.double) * 3.0
|
||||||
|
x.requires_grad = True
|
||||||
|
m = SwooshR()
|
||||||
|
|
||||||
tol = (1.0 / 255.0)
|
tol = (1.0 / 255.0)
|
||||||
torch.autograd.gradcheck(m, x, atol=tol)
|
torch.autograd.gradcheck(m, x, atol=tol)
|
||||||
@ -1474,4 +1550,5 @@ if __name__ == "__main__":
|
|||||||
_test_basic_norm()
|
_test_basic_norm()
|
||||||
_test_double_swish_deriv()
|
_test_double_swish_deriv()
|
||||||
_test_tan_swish_deriv()
|
_test_tan_swish_deriv()
|
||||||
_test_swoosh_deriv()
|
_test_swooshr_deriv()
|
||||||
|
_test_swooshl_deriv()
|
||||||
|
|||||||
@ -29,7 +29,8 @@ from scaling import (
|
|||||||
BasicNorm,
|
BasicNorm,
|
||||||
MaxEig,
|
MaxEig,
|
||||||
DoubleSwish,
|
DoubleSwish,
|
||||||
Swoosh,
|
SwooshL,
|
||||||
|
SwooshR,
|
||||||
TanSwish,
|
TanSwish,
|
||||||
ScaledConv1d,
|
ScaledConv1d,
|
||||||
ScaledLinear, # not as in other dirs.. just scales down initial parameter values.
|
ScaledLinear, # not as in other dirs.. just scales down initial parameter values.
|
||||||
@ -1426,7 +1427,7 @@ class FeedforwardModule(nn.Module):
|
|||||||
min_abs=1.0,
|
min_abs=1.0,
|
||||||
max_abs=5.0,
|
max_abs=5.0,
|
||||||
min_prob=0.25)
|
min_prob=0.25)
|
||||||
self.activation = Swoosh()
|
self.activation = SwooshL()
|
||||||
self.dropout = nn.Dropout(dropout)
|
self.dropout = nn.Dropout(dropout)
|
||||||
self.out_proj = LinearWithAuxLoss(feedforward_dim, embed_dim,
|
self.out_proj = LinearWithAuxLoss(feedforward_dim, embed_dim,
|
||||||
initial_scale=0.01,
|
initial_scale=0.01,
|
||||||
@ -1601,11 +1602,11 @@ class ConvolutionModule(nn.Module):
|
|||||||
channels, channel_dim=1,
|
channels, channel_dim=1,
|
||||||
min_positive=ScheduledFloat((0.0, 0.1), (8000.0, 0.05)),
|
min_positive=ScheduledFloat((0.0, 0.1), (8000.0, 0.05)),
|
||||||
max_positive=1.0,
|
max_positive=1.0,
|
||||||
min_abs=0.75,
|
min_abs=ScheduledFloat((0.0, 0.2), (20000.0, 1.0)),
|
||||||
max_abs=10.0,
|
max_abs=10.0,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.activation = nn.Tanh()
|
self.activation = SwooshR()
|
||||||
|
|
||||||
self.whiten = Whiten(num_groups=1,
|
self.whiten = Whiten(num_groups=1,
|
||||||
whitening_limit=_whitening_schedule(7.5),
|
whitening_limit=_whitening_schedule(7.5),
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user