Support torch 1.6.0 (#433)

Fangjun Kuang authored 2022-06-17 22:24:47 +08:00; committed by GitHub
parent 5379c8e9fa
commit d53f69108f

@@ -52,7 +52,15 @@ class ActivationBalancerFunction(torch.autograd.Function):
         if x.requires_grad:
             if channel_dim < 0:
                 channel_dim += x.ndim
-            sum_dims = [d for d in range(x.ndim) if d != channel_dim]
+
+            # sum_dims = [d for d in range(x.ndim) if d != channel_dim]
+            # The above line is not torch scriptable for torch 1.6.0
+            # torch.jit.frontend.NotSupportedError: comprehension ifs not supported yet:  # noqa
+            sum_dims = []
+            for d in range(x.ndim):
+                if d != channel_dim:
+                    sum_dims.append(d)
+
             xgt0 = x > 0
             proportion_positive = torch.mean(
                 xgt0.to(x.dtype), dim=sum_dims, keepdim=True
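
The workaround can be verified in isolation. Below is a minimal repro sketch (not part of the commit; the function name dims_except is illustrative): under torch.jit.script the explicit loop compiles, while on torch 1.6.0 the commented-out comprehension raises the NotSupportedError quoted above.

    from typing import List

    import torch


    @torch.jit.script
    def dims_except(ndim: int, channel_dim: int) -> List[int]:
        # On torch 1.6.0 the one-liner
        #     return [d for d in range(ndim) if d != channel_dim]
        # fails with NotSupportedError; the explicit loop is accepted.
        sum_dims: List[int] = []
        for d in range(ndim):
            if d != channel_dim:
                sum_dims.append(d)
        return sum_dims


    assert dims_except(3, 2) == [0, 1]  # e.g. a (T, N, C) input with channel last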
@@ -214,7 +222,7 @@ class ScaledLinear(nn.Linear):
     def get_bias(self):
         if self.bias is None or self.bias_scale is None:
             return None
-
-        return self.bias * self.bias_scale.exp()
+        else:
+            return self.bias * self.bias_scale.exp()
 
     def forward(self, input: Tensor) -> Tensor:
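
The explicit else added in this hunk, and repeated in the ScaledConv1d, ScaledConv2d, ActivationBalancer, and DoubleSwish hunks below, is what lets the torch 1.6.0 TorchScript compiler narrow the Optional types: with an early return None and no else branch, the older type refinement does not carry over to the fall-through return. A standalone sketch of the pattern (the function name scaled_bias is illustrative):

    from typing import Optional

    import torch


    @torch.jit.script
    def scaled_bias(bias: Optional[torch.Tensor],
                    scale: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
        if bias is None or scale is None:
            return None
        else:
            # In the else branch TorchScript refines both arguments from
            # Optional[Tensor] to Tensor, so .exp() type-checks.
            return bias * scale.exp()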
@@ -234,6 +242,9 @@ class ScaledConv1d(nn.Conv1d):
     ):
         super(ScaledConv1d, self).__init__(*args, **kwargs)
         initial_scale = torch.tensor(initial_scale).log()
+
+        self.bias_scale: Optional[nn.Parameter]  # for torchscript
+
         self.weight_scale = nn.Parameter(initial_scale.clone().detach())
         if self.bias is not None:
             self.bias_scale = nn.Parameter(initial_scale.clone().detach())
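
The bare annotation self.bias_scale: Optional[nn.Parameter] declares the attribute's type for TorchScript without assigning it; the parameter is only created when a bias exists, so without the annotation the compiler would infer a non-optional type from the assigning branch. A minimal sketch of the idiom, assuming the no-bias branch registers the parameter as None (the module name OptionalScale is illustrative):

    from typing import Optional

    import torch
    import torch.nn as nn


    class OptionalScale(nn.Module):
        def __init__(self, has_bias: bool):
            super().__init__()
            # Declare the type up front so TorchScript treats the attribute
            # as Optional even when the assignment below is skipped.
            self.bias_scale: Optional[nn.Parameter]
            if has_bias:
                self.bias_scale = nn.Parameter(torch.zeros(1))
            else:
                self.register_parameter("bias_scale", None)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            scale = self.bias_scale  # local copy so TorchScript can refine it
            if scale is None:
                return x
            else:
                return x * scale.exp()


    for flag in (True, False):
        m = torch.jit.script(OptionalScale(flag))
        m(torch.randn(2))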
@@ -262,6 +273,7 @@ class ScaledConv1d(nn.Conv1d):
         bias_scale = self.bias_scale
         if bias is None or bias_scale is None:
             return None
-        return bias * bias_scale.exp()
+        else:
+            return bias * bias_scale.exp()
 
     def forward(self, input: Tensor) -> Tensor:
@@ -331,6 +343,7 @@ class ScaledConv2d(nn.Conv2d):
         bias_scale = self.bias_scale
         if bias is None or bias_scale is None:
             return None
-        return bias * bias_scale.exp()
+        else:
+            return bias * bias_scale.exp()
 
     def _conv_forward(self, input, weight):
@@ -412,7 +425,7 @@ class ActivationBalancer(torch.nn.Module):
     def forward(self, x: Tensor) -> Tensor:
         if torch.jit.is_scripting():
             return x
-
-        return ActivationBalancerFunction.apply(
-            x,
-            self.channel_dim,
+        else:
+            return ActivationBalancerFunction.apply(
+                x,
+                self.channel_dim,
@@ -461,6 +474,7 @@ class DoubleSwish(torch.nn.Module):
         """
         if torch.jit.is_scripting():
             return x * torch.sigmoid(x - 1.0)
-        return DoubleSwishFunction.apply(x)
+        else:
+            return DoubleSwishFunction.apply(x)
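
For reference, the guarded branches compute the same values: torch.jit.is_scripting() only routes scripted execution away from the custom autograd.Function objects, which older TorchScript cannot compile and which exist for their hand-written backward passes. A quick eager-mode check (assuming DoubleSwish is importable from this scaling module):

    import torch

    from scaling import DoubleSwish  # assumed import path

    m = DoubleSwish()
    x = torch.randn(4, 8)
    # double_swish(x) = x * sigmoid(x - 1), matching the scripted fallback.
    assert torch.allclose(m(x), x * torch.sigmoid(x - 1.0))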