Fix torch.jit.script

pkufool 2022-08-10 16:49:35 +08:00
parent f51e64dada
commit 88400a4435


@@ -424,18 +424,25 @@ class ActivationBalancer(torch.nn.Module):
         self.max_abs = max_abs
 
     def forward(self, x: Tensor) -> Tensor:
-        if torch.jit.is_scripting() or torch.jit.is_tracing():
+        # Pytorch issue: https://github.com/pytorch/pytorch/issues/47379
+        # Fixed: https://github.com/pytorch/pytorch/pull/49853
+        # The fix was included in v1.9.0
+        # https://github.com/pytorch/pytorch/releases/tag/v1.9.0
+        if torch.jit.is_scripting():
             return x
-        else:
-            return ActivationBalancerFunction.apply(
-                x,
-                self.channel_dim,
-                self.min_positive,
-                self.max_positive,
-                self.max_factor,
-                self.min_abs,
-                self.max_abs,
-            )
+
+        if torch.jit.is_tracing():
+            return x
+        else:
+            return ActivationBalancerFunction.apply(
+                x,
+                self.channel_dim,
+                self.min_positive,
+                self.max_positive,
+                self.max_factor,
+                self.min_abs,
+                self.max_abs,
+            )
 
 
 class DoubleSwishFunction(torch.autograd.Function):
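In the new arrangement, a scripted module hits the torch.jit.is_scripting() early return and never evaluates torch.jit.is_tracing(); per the comments above, that is what keeps torch.jit.script usable on PyTorch releases before v1.9.0, where is_tracing() was not yet scriptable. A minimal sketch of the resulting behaviour, using a hypothetical ToyBalancer (same guard structure; x * 2.0 stands in for ActivationBalancerFunction.apply, whose real arguments are shown above):

    import torch
    from torch import Tensor


    class ToyBalancer(torch.nn.Module):
        # Hypothetical stand-in, not the real ActivationBalancer: same guard
        # structure, with x * 2.0 in place of the custom autograd function.
        def forward(self, x: Tensor) -> Tensor:
            if torch.jit.is_scripting():
                return x

            if torch.jit.is_tracing():
                return x
            else:
                return x * 2.0


    m = ToyBalancer()
    x = torch.ones(3)
    print(m(x))                     # eager mode: tensor([2., 2., 2.])
    scripted = torch.jit.script(m)
    print(scripted(x))              # scripted: tensor([1., 1., 1.]), the identity path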
@@ -473,10 +480,17 @@ class DoubleSwish(torch.nn.Module):
         """Return double-swish activation function which is an approximation to Swish(Swish(x)),
         that we approximate closely with x * sigmoid(x-1).
         """
-        if torch.jit.is_scripting() or torch.jit.is_tracing():
+        # Pytorch issue: https://github.com/pytorch/pytorch/issues/47379
+        # Fixed: https://github.com/pytorch/pytorch/pull/49853
+        # The fix was included in v1.9.0
+        # https://github.com/pytorch/pytorch/releases/tag/v1.9.0
+        if torch.jit.is_scripting():
             return x * torch.sigmoid(x - 1.0)
-        else:
-            return DoubleSwishFunction.apply(x)
+
+        if torch.jit.is_tracing():
+            return x * torch.sigmoid(x - 1.0)
+        else:
+            return DoubleSwishFunction.apply(x)
 
 
 class ScaledEmbedding(nn.Module):
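The docstring's claim that x * sigmoid(x - 1) closely approximates Swish(Swish(x)) can be checked numerically. A quick sketch, not from the repo; both helpers below are defined here just for the comparison:

    import torch


    def double_swish(x: torch.Tensor) -> torch.Tensor:
        # The closed-form approximation used by DoubleSwish above.
        return x * torch.sigmoid(x - 1.0)


    def swish_of_swish(x: torch.Tensor) -> torch.Tensor:
        # The target function: Swish(Swish(x)), where Swish(x) = x * sigmoid(x).
        s = x * torch.sigmoid(x)
        return s * torch.sigmoid(s)


    x = torch.linspace(-8.0, 8.0, steps=1000)
    err = (double_swish(x) - swish_of_swish(x)).abs().max()
    print(err)  # roughly 0.06 (near x ~ 3), so the two curves stay close over this range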