Support torch 1.6.0 (#433)

This commit is contained in:
Fangjun Kuang 2022-06-17 22:24:47 +08:00 committed by GitHub
parent 5379c8e9fa
commit d53f69108f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -52,7 +52,15 @@ class ActivationBalancerFunction(torch.autograd.Function):
if x.requires_grad:
if channel_dim < 0:
channel_dim += x.ndim
sum_dims = [d for d in range(x.ndim) if d != channel_dim]
# sum_dims = [d for d in range(x.ndim) if d != channel_dim]
# The above line is not torch scriptable for torch 1.6.0
# torch.jit.frontend.NotSupportedError: comprehension ifs not supported yet: # noqa
sum_dims = 0
for d in range(x.ndim):
if d != channel_dim:
sum_dims += d
xgt0 = x > 0
proportion_positive = torch.mean(
xgt0.to(x.dtype), dim=sum_dims, keepdim=True
@ -214,8 +222,8 @@ class ScaledLinear(nn.Linear):
def get_bias(self):
if self.bias is None or self.bias_scale is None:
return None
return self.bias * self.bias_scale.exp()
else:
return self.bias * self.bias_scale.exp()
def forward(self, input: Tensor) -> Tensor:
return torch.nn.functional.linear(
@ -234,6 +242,9 @@ class ScaledConv1d(nn.Conv1d):
):
super(ScaledConv1d, self).__init__(*args, **kwargs)
initial_scale = torch.tensor(initial_scale).log()
self.bias_scale: Optional[nn.Parameter] # for torchscript
self.weight_scale = nn.Parameter(initial_scale.clone().detach())
if self.bias is not None:
self.bias_scale = nn.Parameter(initial_scale.clone().detach())
@ -262,7 +273,8 @@ class ScaledConv1d(nn.Conv1d):
bias_scale = self.bias_scale
if bias is None or bias_scale is None:
return None
return bias * bias_scale.exp()
else:
return bias * bias_scale.exp()
def forward(self, input: Tensor) -> Tensor:
F = torch.nn.functional
@ -331,7 +343,8 @@ class ScaledConv2d(nn.Conv2d):
bias_scale = self.bias_scale
if bias is None or bias_scale is None:
return None
return bias * bias_scale.exp()
else:
return bias * bias_scale.exp()
def _conv_forward(self, input, weight):
F = torch.nn.functional
@ -412,16 +425,16 @@ class ActivationBalancer(torch.nn.Module):
def forward(self, x: Tensor) -> Tensor:
if torch.jit.is_scripting():
return x
return ActivationBalancerFunction.apply(
x,
self.channel_dim,
self.min_positive,
self.max_positive,
self.max_factor,
self.min_abs,
self.max_abs,
)
else:
return ActivationBalancerFunction.apply(
x,
self.channel_dim,
self.min_positive,
self.max_positive,
self.max_factor,
self.min_abs,
self.max_abs,
)
class DoubleSwishFunction(torch.autograd.Function):
@ -461,7 +474,8 @@ class DoubleSwish(torch.nn.Module):
"""
if torch.jit.is_scripting():
return x * torch.sigmoid(x - 1.0)
return DoubleSwishFunction.apply(x)
else:
return DoubleSwishFunction.apply(x)
class ScaledEmbedding(nn.Module):