mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
Change Swoosh formula so left crossing is near zero; change min_positive, max_positive of ActivationBalancer.
This commit is contained in:
parent
d5bfca4f49
commit
67812276ed
@ -1235,15 +1235,15 @@ class SwooshFunction(torch.autograd.Function):
|
||||
with torch.enable_grad():
|
||||
x = x.detach()
|
||||
x.requires_grad = True
|
||||
y = torch.logaddexp(zero, x - 1.) - 0.08 * x - 0.313261687
|
||||
y = torch.logaddexp(zero, x - 4.0) - 0.1 * x - 0.035
|
||||
|
||||
if not requires_grad:
|
||||
return y
|
||||
y.backward(gradient = torch.ones_like(y))
|
||||
|
||||
grad = x.grad
|
||||
floor = -0.08
|
||||
ceil = 0.925 # real ceil would be 0.092, give it extra room for roundoff.
|
||||
floor = -0.1
|
||||
ceil = 0.905 # real ceil would be 0.09, give it extra room for roundoff.
|
||||
|
||||
d_scaled = ((grad - floor) * (255.0 / (ceil - floor)) + torch.rand_like(grad))
|
||||
if __name__ == "__main__":
|
||||
@ -1261,8 +1261,8 @@ class SwooshFunction(torch.autograd.Function):
|
||||
def backward(ctx, y_grad: Tensor) -> Tensor:
|
||||
d, = ctx.saved_tensors
|
||||
# the same constants as used in forward pass.
|
||||
floor = -0.08
|
||||
ceil = 0.925
|
||||
floor = -0.1
|
||||
ceil = 0.905
|
||||
d = (d * ((ceil - floor) / 255.0) + floor)
|
||||
return (y_grad * d)
|
||||
|
||||
@ -1273,7 +1273,7 @@ class Swoosh(torch.nn.Module):
|
||||
"""
|
||||
if torch.jit.is_scripting():
|
||||
zero = torch.tensor(0.0, dtype=x.dtype, device=x.device)
|
||||
return torch.logaddexp(zero, x - 1.) - 0.08 * x - 0.313261687
|
||||
return torch.logaddexp(zero, x - 4.0) - 0.1 * x - 0.035
|
||||
return SwooshFunction.apply(x)
|
||||
|
||||
|
||||
|
||||
@ -1421,7 +1421,8 @@ class FeedforwardModule(nn.Module):
|
||||
|
||||
self.hidden_balancer = ActivationBalancer(feedforward_dim,
|
||||
channel_dim=-1,
|
||||
min_positive=ScheduledFloat((0.0, 0.1), (8000.0, 0.05)),
|
||||
min_positive=0.3,
|
||||
max_positive=1.0,
|
||||
min_abs=2.0,
|
||||
max_abs=10.0,
|
||||
min_prob=0.25)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user