mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
Merge branch 'scaled_adam_exp467' into scaled_adam_exp472
This commit is contained in:
commit
edd4bf5312
@ -928,7 +928,6 @@ class DoubleSwishFunction(torch.autograd.Function):
|
|||||||
d = (d * ((ceil - floor) / 255.0) + floor)
|
d = (d * ((ceil - floor) / 255.0) + floor)
|
||||||
return (y_grad * d)
|
return (y_grad * d)
|
||||||
|
|
||||||
|
|
||||||
class DoubleSwish(torch.nn.Module):
|
class DoubleSwish(torch.nn.Module):
|
||||||
def forward(self, x: Tensor) -> Tensor:
|
def forward(self, x: Tensor) -> Tensor:
|
||||||
"""Return double-swish activation function which is an approximation to Swish(Swish(x)),
|
"""Return double-swish activation function which is an approximation to Swish(Swish(x)),
|
||||||
@ -939,6 +938,68 @@ class DoubleSwish(torch.nn.Module):
|
|||||||
return DoubleSwishFunction.apply(x)
|
return DoubleSwishFunction.apply(x)
|
||||||
|
|
||||||
|
|
||||||
|
class TanSwishFunction(torch.autograd.Function):
|
||||||
|
"""
|
||||||
|
double_swish(x) = tan(x) * torch.sigmoid(x-1)
|
||||||
|
|
||||||
|
|
||||||
|
entering: d/dx(tanh(x) * sigmoid(x-1))
|
||||||
|
into wolfram alpha, I see that the range of this function is
|
||||||
|
-0.0498087 <= y <= 0.417894
|
||||||
|
let's make it (as we don't know how this was rounded):
|
||||||
|
-0.0498088 <= y <= 0.417895
|
||||||
|
"""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def forward(ctx, x: Tensor) -> Tensor:
|
||||||
|
requires_grad = x.requires_grad
|
||||||
|
if not requires_grad:
|
||||||
|
return torch.tanh(x) * torch.sigmoid(x - 1.0)
|
||||||
|
|
||||||
|
x_dtype = x.dtype
|
||||||
|
if x.dtype == torch.float16:
|
||||||
|
x = x.to(torch.float32)
|
||||||
|
|
||||||
|
with torch.cuda.amp.autocast(enabled=False):
|
||||||
|
with torch.enable_grad():
|
||||||
|
x = x.detach()
|
||||||
|
x.requires_grad = True
|
||||||
|
y = torch.tanh(x) * torch.sigmoid(x - 1.0)
|
||||||
|
y.backward(gradient=torch.ones_like(y))
|
||||||
|
grad = x.grad
|
||||||
|
floor = -0.0498088
|
||||||
|
ceil = 0.417895
|
||||||
|
d_scaled = ((grad - floor) * (255.0 / (ceil - floor)) + torch.rand_like(grad))
|
||||||
|
if __name__ == "__main__":
|
||||||
|
# for self-testing only.
|
||||||
|
assert d_scaled.min() >= 0.0
|
||||||
|
assert d_scaled.max() < 256.0
|
||||||
|
|
||||||
|
d_int = d_scaled.to(torch.uint8)
|
||||||
|
ctx.save_for_backward(d_int)
|
||||||
|
if x.dtype == torch.float16 or torch.is_autocast_enabled():
|
||||||
|
y = y.to(torch.float16)
|
||||||
|
return y
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def backward(ctx, y_grad: Tensor) -> Tensor:
|
||||||
|
d, = ctx.saved_tensors
|
||||||
|
# the same constants as used in forward pass.
|
||||||
|
floor = -0.0498088
|
||||||
|
ceil = 0.417895
|
||||||
|
d = (d * ((ceil - floor) / 255.0) + floor)
|
||||||
|
return (y_grad * d)
|
||||||
|
|
||||||
|
|
||||||
|
class TanSwish(torch.nn.Module):
|
||||||
|
def forward(self, x: Tensor) -> Tensor:
|
||||||
|
"""Return tan-swish activation function which is tanh(x) sigmoid(x-1)n
|
||||||
|
"""
|
||||||
|
if torch.jit.is_scripting():
|
||||||
|
return x.tanh() * torch.sigmoid(x - 1.0)
|
||||||
|
return TanSwishFunction.apply(x)
|
||||||
|
|
||||||
|
|
||||||
class ScheduledFloat(torch.nn.Module):
|
class ScheduledFloat(torch.nn.Module):
|
||||||
"""
|
"""
|
||||||
This object is a torch.nn.Module only because we want it to show up in [top_level module].modules();
|
This object is a torch.nn.Module only because we want it to show up in [top_level module].modules();
|
||||||
@ -1147,6 +1208,20 @@ def _test_double_swish_deriv():
|
|||||||
x.requires_grad = True
|
x.requires_grad = True
|
||||||
y = m(x)
|
y = m(x)
|
||||||
|
|
||||||
|
def _test_tan_swish_deriv():
|
||||||
|
x = torch.randn(10, 12, dtype=torch.double) * 3.0
|
||||||
|
x.requires_grad = True
|
||||||
|
m = TanSwish()
|
||||||
|
|
||||||
|
tol = ((1.2-(-0.043637))/255.0)
|
||||||
|
torch.autograd.gradcheck(m, x, atol=tol)
|
||||||
|
|
||||||
|
|
||||||
|
# for self-test.
|
||||||
|
x = torch.randn(1000, 1000, dtype=torch.double) * 3.0
|
||||||
|
x.requires_grad = True
|
||||||
|
y = m(x)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def _test_softmax():
|
def _test_softmax():
|
||||||
@ -1173,3 +1248,4 @@ if __name__ == "__main__":
|
|||||||
_test_activation_balancer_magnitude()
|
_test_activation_balancer_magnitude()
|
||||||
_test_basic_norm()
|
_test_basic_norm()
|
||||||
_test_double_swish_deriv()
|
_test_double_swish_deriv()
|
||||||
|
_test_tan_swish_deriv()
|
||||||
|
|||||||
@ -29,6 +29,7 @@ from scaling import (
|
|||||||
BasicNorm,
|
BasicNorm,
|
||||||
MaxEig,
|
MaxEig,
|
||||||
DoubleSwish,
|
DoubleSwish,
|
||||||
|
TanSwish,
|
||||||
ScaledConv1d,
|
ScaledConv1d,
|
||||||
ScaledLinear, # not as in other dirs.. just scales down initial parameter values.
|
ScaledLinear, # not as in other dirs.. just scales down initial parameter values.
|
||||||
Whiten,
|
Whiten,
|
||||||
@ -1317,7 +1318,7 @@ class AttentionSqueeze(nn.Module):
|
|||||||
max_factor=0.02,
|
max_factor=0.02,
|
||||||
min_prob=0.1,
|
min_prob=0.1,
|
||||||
)
|
)
|
||||||
self.bottleneck_activation = nn.Tanh() # in bottleneck
|
self.bottleneck_activation = TanSwish() # in bottleneck
|
||||||
self.activation = Identity() # for diagnostics
|
self.activation = Identity() # for diagnostics
|
||||||
|
|
||||||
# the next two balancers are only to stop parameter-magnitude 'drift': we have
|
# the next two balancers are only to stop parameter-magnitude 'drift': we have
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user