Change BasicNorm by adding 1+eps denominator; fix to (unused) DoubleSwish, revert to old status.
parent cff350d8de
commit 049174722f
@@ -491,8 +491,10 @@ class BasicNorm(torch.nn.Module):
             # gradients to allow the parameter to get back into the allowed
             # region if it happens to exit it.
             eps = eps.clamp(min=self.eps_min, max=self.eps_max)
+        eps = eps.exp()
         scales = (
-            torch.mean(x ** 2, dim=self.channel_dim, keepdim=True) + eps.exp()
+            (torch.mean(x ** 2, dim=self.channel_dim, keepdim=True) + eps) /
+            (1.0 + eps)
         ) ** -0.5
         return x * scales
 
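
Note on the BasicNorm change: dividing by (1 + eps) inside the root means a frame whose mean square is already 1.0 passes through unchanged, whereas the old formula always shrank it by (1 + eps) ** -0.5. A minimal standalone sketch of the two variants (helper names are illustrative, not the repository code):

    import torch

    def old_scales(x: torch.Tensor, eps: torch.Tensor, channel_dim: int = -1) -> torch.Tensor:
        # previous behaviour: scales = (E[x^2] + exp(eps)) ** -0.5
        return (torch.mean(x ** 2, dim=channel_dim, keepdim=True) + eps.exp()) ** -0.5

    def new_scales(x: torch.Tensor, eps: torch.Tensor, channel_dim: int = -1) -> torch.Tensor:
        # new behaviour: scales = ((E[x^2] + exp(eps)) / (1 + exp(eps))) ** -0.5
        e = eps.exp()
        return ((torch.mean(x ** 2, dim=channel_dim, keepdim=True) + e) / (1.0 + e)) ** -0.5

    x = torch.randn(4, 64)            # per-frame mean square ~ 1
    eps = torch.tensor(0.25)          # example log-eps value only
    print(old_scales(x, eps).mean())  # noticeably below 1: unit-RMS input gets shrunk
    print(new_scales(x, eps).mean())  # ~1: unit-RMS input passes through unchanged
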
@@ -1330,9 +1332,8 @@ class MaxEig(torch.nn.Module):
 
 class DoubleSwishFunction(torch.autograd.Function):
     """
-    double_swish(x) = x * (torch.sigmoid(x-1) + alpha)
+    double_swish(x) = x * torch.sigmoid(x-1)
 
-    for e.g. alpha=-0.05 (user supplied).
     This is a definition, originally motivated by its close numerical
     similarity to swish(swish(x)), where swish(x) = x * sigmoid(x).
 
@@ -1356,18 +1357,8 @@ class DoubleSwishFunction(torch.autograd.Function):
         s = torch.sigmoid(x - 1.0)
         y = x * s
 
-        alpha = -0.05
-        beta = 0.05
-        x_limit = 0.15
-
-        # another part of this formula is:
-        # ... + 0.2 * x.clamp(min=-0.15, max=0.15)
-        # the deriv of this is
-        # beta * (x.abs() < x_limit).
-
         if requires_grad:
-            deriv = (y * (1 - s) + s)  # ignores the alpha part.
-            deriv = deriv + (x.abs() < x_limit) * beta
+            deriv = (y * (1 - s) + s)
 
             # notes on derivative of x * sigmoid(x - 1):
             # https://www.wolframalpha.com/input?i=d%2Fdx+%28x+*+sigmoid%28x-1%29%29
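
The single deriv line kept above is the exact derivative of x * sigmoid(x - 1): d/dx [x * s] = s + x * s * (1 - s) = y * (1 - s) + s. A quick standalone check against autograd (double precision, not part of the diff):

    import torch

    x = torch.randn(1000, dtype=torch.double, requires_grad=True)
    s = torch.sigmoid(x - 1.0)
    y = x * s
    deriv = y * (1 - s) + s                  # closed form used in the diff
    (g,) = torch.autograd.grad(y.sum(), x)   # autograd reference
    assert torch.allclose(deriv, g)
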
@@ -1376,7 +1367,7 @@ class DoubleSwishFunction(torch.autograd.Function):
             # the combination of "+ torch.rand_like(deriv)" and casting to torch.uint8 (which
             # floors), should be expectation-preserving.
             floor = -0.044
-            ceil = 1.2 + beta
+            ceil = 1.2
             d_scaled = ((deriv - floor) * (255.0 / (ceil - floor)) + torch.rand_like(deriv))
             if __name__ == "__main__":
                 # for self-testing only.
@@ -1384,8 +1375,6 @@ class DoubleSwishFunction(torch.autograd.Function):
                 assert d_scaled.max() < 256.0
             d_int = d_scaled.to(torch.uint8)
             ctx.save_for_backward(d_int)
-        # on wolframalpha, do: (x * sigmoid(x-1) - 0.05 * x + 0.05 * min(0.15, max(-0.15, x)) + 0.025) from x=-3 to 2
-        y = y + alpha * x + beta * x.clamp(min=-x_limit, max=x_limit) - 0.025
         if x.dtype == torch.float16 or torch.is_autocast_enabled():
            y = y.to(torch.float16)
         return y
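
For context on the ceil change: the derivative is cached as uint8 by mapping [floor, ceil] affinely onto [0, 255] and adding uniform noise before the implicit floor, so the quantization is expectation-preserving; backward() inverts the map. A standalone sketch using the constants from the backward pass (illustrative only):

    import torch

    floor, ceil = -0.043637, 1.2
    deriv = torch.rand(8) * (ceil - floor) + floor              # pretend derivative values
    d_scaled = (deriv - floor) * (255.0 / (ceil - floor)) + torch.rand_like(deriv)
    d_int = d_scaled.to(torch.uint8)                            # what forward() caches
    d = d_int * ((ceil - floor) / 255.0) + floor                # what backward() recovers
    print((d - deriv).abs().max())                              # error is O((ceil - floor) / 255)
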
@@ -1394,13 +1383,11 @@ class DoubleSwishFunction(torch.autograd.Function):
     def backward(ctx, y_grad: Tensor) -> Tensor:
         d, = ctx.saved_tensors
         # the same constants as used in forward pass.
-        alpha = -0.05
-        beta = 0.05
         floor = -0.043637
-        ceil = 1.2 + beta
+        ceil = 1.2
 
         d = (d * ((ceil - floor) / 255.0) + floor)
-        return (y_grad * (d + alpha))
+        return y_grad * d
 
 class DoubleSwish(torch.nn.Module):
     def __init__(self):
@@ -1412,7 +1399,7 @@ class DoubleSwish(torch.nn.Module):
         that we approximate closely with x * sigmoid(x-1).
         """
         if torch.jit.is_scripting():
-            return x * (torch.sigmoid(x - 1.0) - 0.05) + 0.05 * x.clamp(min=-0.15, max=0.15)
+            return x * torch.sigmoid(x - 1.0)
         return DoubleSwishFunction.apply(x)
 
 
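
To see what the revert does to the scripted path, the removed and the restored expressions can be evaluated side by side (standalone snippet, values purely illustrative):

    import torch

    x = torch.linspace(-3.0, 2.0, 7)
    old_scripted = x * (torch.sigmoid(x - 1.0) - 0.05) + 0.05 * x.clamp(min=-0.15, max=0.15)
    new_scripted = x * torch.sigmoid(x - 1.0)
    print(torch.stack([x, old_scripted, new_scripted], dim=0))
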
@@ -1741,8 +1728,6 @@ def _test_basic_norm():
     y_rms = (y ** 2).mean().sqrt()
     print("x rms = ", x_rms)
     print("y rms = ", y_rms)
-    assert y_rms < x_rms
-    assert y_rms > 0.5 * x_rms
 
 
 def _test_double_swish_deriv():
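
The dropped assertions are presumably a consequence of the BasicNorm change above: with the (1 + eps) denominator, an input whose RMS is below 1 is scaled up towards unit RMS, so y_rms < x_rms no longer holds in general. A hypothetical illustration (the eps value is chosen only for the example):

    import torch

    eps = torch.tensor(0.25).exp()          # example epsilon only
    x = 0.5 * torch.randn(1000, 64)         # input RMS ~ 0.5
    scales = ((x.pow(2).mean(dim=-1, keepdim=True) + eps) / (1.0 + eps)) ** -0.5
    y = x * scales
    print((x ** 2).mean().sqrt(), (y ** 2).mean().sqrt())   # y RMS ends up above x RMS
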