Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-12-11 06:55:27 +00:00)
Use Swoosh not DoubleSwish in zipformer; fix constants in Swoosh
commit 14267a5194
parent ec10573edc
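
For context: with the corrected constants, the activation is swoosh(x) = log(1 + exp(x - 4)) - 0.08*x - 0.15, whose derivative sigmoid(x - 4) - 0.08 lies strictly between -0.08 and 0.92; the commit replaces the earlier 0.055-based constants throughout. A minimal standalone sketch (not the repository code) that evaluates the formula and checks the gradient bounds:

    import torch

    def swoosh(x: torch.Tensor) -> torch.Tensor:
        # swoosh(x) = log(1 + exp(x - 4)) - 0.08 * x - 0.15
        one = torch.tensor(1.0, dtype=x.dtype, device=x.device)
        return torch.logaddexp(one, x - 4.0) - 0.08 * x - 0.15

    x = torch.linspace(-30.0, 30.0, steps=2001, requires_grad=True)
    y = swoosh(x)
    y.sum().backward()
    # d/dx swoosh(x) = sigmoid(x - 4) - 0.08, so every gradient lies in (-0.08, 0.92).
    print(x.grad.min().item(), x.grad.max().item())
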
@@ -1216,9 +1216,9 @@ class TanSwish(torch.nn.Module):
 class SwooshFunction(torch.autograd.Function):
     """
-    swoosh(x) = log(1 + exp(x-4)) - 0.055*x - 0.15
+    swoosh(x) = log(1 + exp(x-4)) - 0.08*x - 0.15
 
-    derivatives are between -0.055 and 1-0.055.
+    derivatives are between -0.08 and 0.92.
     """
 
     @staticmethod
@@ -1235,15 +1235,15 @@ class SwooshFunction(torch.autograd.Function):
         with torch.enable_grad():
             x = x.detach()
             x.requires_grad = True
-            y = torch.logaddexp(one, x - 4) - 0.055 * x - 0.15
+            y = torch.logaddexp(one, x - 4) - 0.08 * x - 0.15
 
             if not requires_grad:
                 return y
             y.backward(gradient = torch.ones_like(y))
 
             grad = x.grad
-            floor = -0.055
-            ceil = 0.946 # real ceil would be 0.945, give it extra room for roundoff.
+            floor = -0.08
+            ceil = 0.925 # real ceil would be 0.92, give it extra room for roundoff.
 
             d_scaled = ((grad - floor) * (255.0 / (ceil - floor)) + torch.rand_like(grad))
@@ -1261,8 +1261,8 @@ class SwooshFunction(torch.autograd.Function):
     def backward(ctx, y_grad: Tensor) -> Tensor:
         d, = ctx.saved_tensors
         # the same constants as used in forward pass.
-        floor = -0.055
-        ceil = 0.946
+        floor = -0.08
+        ceil = 0.925
         d = (d * ((ceil - floor) / 255.0) + floor)
         return (y_grad * d)
@@ -1273,7 +1273,7 @@ class Swoosh(torch.nn.Module):
         """
         if torch.jit.is_scripting():
             one = torch.tensor(1.0, dtype=x.dtype, device=x.device)
-            return torch.logaddexp(one, x - 4) - 0.055 * x - 0.15
+            return torch.logaddexp(one, x - 4) - 0.08 * x - 0.15
         return SwooshFunction.apply(x)
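
The floor/ceil constants have to change in both forward and backward because SwooshFunction saves the derivative scaled to the 0..255 range (with random dithering via torch.rand_like, presumably stored as uint8; the conversion is outside this hunk), and the backward pass inverts exactly that mapping, so the two sides must agree. A small illustrative round-trip sketch with the new constants, assuming a uint8 store (not the repository code):

    import torch

    floor, ceil = -0.08, 0.925                     # same constants in forward and backward
    grad = torch.rand(5) * (0.92 + 0.08) - 0.08    # fake derivatives in (-0.08, 0.92)

    # forward: scale to [0, 255], dither, and store as uint8 to save memory
    d_scaled = (grad - floor) * (255.0 / (ceil - floor)) + torch.rand_like(grad)
    d_int = d_scaled.clamp(min=0, max=255).to(torch.uint8)

    # backward: undo the scaling with the same floor/ceil
    d = d_int.to(torch.float32) * ((ceil - floor) / 255.0) + floor
    print((d - grad).abs().max())                  # quantization error stays well under 0.01
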
@@ -29,6 +29,7 @@ from scaling import (
     BasicNorm,
     MaxEig,
     DoubleSwish,
+    Swoosh,
     TanSwish,
     ScaledConv1d,
     ScaledLinear,  # not as in other dirs.. just scales down initial parameter values.
@@ -1421,10 +1422,10 @@ class FeedforwardModule(nn.Module):
         self.hidden_balancer = ActivationBalancer(feedforward_dim,
                                                   channel_dim=-1,
                                                   min_positive=ScheduledFloat((0.0, 0.1), (8000.0, 0.05)),
-                                                  min_abs=1.5,
-                                                  max_abs=15.0,
+                                                  min_abs=2.0,
+                                                  max_abs=10.0,
                                                   min_prob=0.25)
-        self.activation = DoubleSwish()
+        self.activation = Swoosh()
         self.dropout = nn.Dropout(dropout)
         self.out_proj = LinearWithAuxLoss(feedforward_dim, embed_dim,
                                           initial_scale=0.01,
@@ -1599,10 +1600,11 @@ class ConvolutionModule(nn.Module):
             channels, channel_dim=1,
             min_positive=ScheduledFloat((0.0, 0.1), (8000.0, 0.05)),
             max_positive=1.0,
-            max_abs=ScheduledFloat((0.0, 5.0), (8000.0, 10.0), default=10),
+            min_abs=1.0,
+            max_abs=ScheduledFloat((0.0, 10.0), (8000.0, 20.0), default=10),
         )
 
-        self.activation = DoubleSwish()
+        self.activation = Swoosh()
 
         self.whiten = Whiten(num_groups=1,
                              whitening_limit=_whitening_schedule(7.5),
@@ -547,7 +547,7 @@ def attach_diagnostics(
         module.register_forward_hook(forward_hook)
         module.register_backward_hook(backward_hook)
 
-        if type(module).__name__ in ["Sigmoid", "Tanh", "ReLU", "TanSwish", "Swish", "DoubleSwish"]:
+        if type(module).__name__ in ["Sigmoid", "Tanh", "ReLU", "TanSwish", "Swish", "DoubleSwish", "Swoosh"]:
            # For these specific module types, accumulate some additional diagnostics
            # that can help us improve the activation function. These require a lot of memory,
            # to save the forward activations, so limit this to some select classes.