Fix comments

This commit is contained in:
pkufool 2022-08-10 17:53:49 +08:00
parent 88400a4435
commit 56afb4aa3c

View File

@ -152,7 +152,7 @@ class BasicNorm(torch.nn.Module):
self.register_buffer("eps", torch.tensor(eps).log().detach())
def forward(self, x: Tensor) -> Tensor:
if not torch.jit.is_tracing():
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
assert x.shape[self.channel_dim] == self.num_channels
scales = (
torch.mean(x ** 2, dim=self.channel_dim, keepdim=True)
@ -430,19 +430,18 @@ class ActivationBalancer(torch.nn.Module):
# https://github.com/pytorch/pytorch/releases/tag/v1.9.0
if torch.jit.is_scripting():
return x
elif torch.jit.is_tracing():
return x
else:
if torch.jit.is_tracing():
return x
else:
return ActivationBalancerFunction.apply(
x,
self.channel_dim,
self.min_positive,
self.max_positive,
self.max_factor,
self.min_abs,
self.max_abs,
)
return ActivationBalancerFunction.apply(
x,
self.channel_dim,
self.min_positive,
self.max_positive,
self.max_factor,
self.min_abs,
self.max_abs,
)
class DoubleSwishFunction(torch.autograd.Function):
@ -486,11 +485,10 @@ class DoubleSwish(torch.nn.Module):
# https://github.com/pytorch/pytorch/releases/tag/v1.9.0
if torch.jit.is_scripting():
return x * torch.sigmoid(x - 1.0)
elif torch.jit.is_tracing():
return x * torch.sigmoid(x - 1.0)
else:
if torch.jit.is_tracing():
return x * torch.sigmoid(x - 1.0)
else:
return DoubleSwishFunction.apply(x)
return DoubleSwishFunction.apply(x)
class ScaledEmbedding(nn.Module):