Support torch 1.6.0 (#433)

Fangjun Kuang 2022-06-17 22:24:47 +08:00 committed by GitHub
parent 5379c8e9fa
commit d53f69108f

@@ -52,7 +52,15 @@ class ActivationBalancerFunction(torch.autograd.Function):
         if x.requires_grad:
             if channel_dim < 0:
                 channel_dim += x.ndim
-            sum_dims = [d for d in range(x.ndim) if d != channel_dim]
+            # sum_dims = [d for d in range(x.ndim) if d != channel_dim]
+            # The above line is not torch scriptable for torch 1.6.0
+            # torch.jit.frontend.NotSupportedError: comprehension ifs not supported yet:  # noqa
+            sum_dims = []
+            for d in range(x.ndim):
+                if d != channel_dim:
+                    sum_dims.append(d)
             xgt0 = x > 0
             proportion_positive = torch.mean(
                 xgt0.to(x.dtype), dim=sum_dims, keepdim=True
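For context, here is a minimal, self-contained sketch (not taken from the repository) of the scriptable pattern this hunk switches to: the reduction dims are collected with a plain loop because the torch 1.6.0 TorchScript frontend rejects list comprehensions that contain an `if` clause. The function name `proportion_positive` and its signature are invented for illustration.

```python
from typing import List

import torch


@torch.jit.script
def proportion_positive(x: torch.Tensor, channel_dim: int) -> torch.Tensor:
    """Fraction of positive entries, averaged over every dim except channel_dim."""
    if channel_dim < 0:
        channel_dim += x.dim()
    # A comprehension such as
    #   sum_dims = [d for d in range(x.dim()) if d != channel_dim]
    # is rejected by the torch 1.6.0 scripting frontend, so the list is
    # built with an explicit loop instead.
    sum_dims: List[int] = []
    for d in range(x.dim()):
        if d != channel_dim:
            sum_dims.append(d)
    xgt0 = (x > 0).to(x.dtype)
    return torch.mean(xgt0, dim=sum_dims, keepdim=True)


# Example: for an (N, C, T) tensor with channel_dim == 1 the result has shape (1, C, 1).
p = proportion_positive(torch.randn(4, 8, 16), 1)
```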
@@ -214,8 +222,8 @@ class ScaledLinear(nn.Linear):
     def get_bias(self):
         if self.bias is None or self.bias_scale is None:
             return None
-        return self.bias * self.bias_scale.exp()
+        else:
+            return self.bias * self.bias_scale.exp()
 
     def forward(self, input: Tensor) -> Tensor:
         return torch.nn.functional.linear(
@@ -234,6 +242,9 @@ class ScaledConv1d(nn.Conv1d):
     ):
         super(ScaledConv1d, self).__init__(*args, **kwargs)
         initial_scale = torch.tensor(initial_scale).log()
+        self.bias_scale: Optional[nn.Parameter]  # for torchscript
         self.weight_scale = nn.Parameter(initial_scale.clone().detach())
         if self.bias is not None:
             self.bias_scale = nn.Parameter(initial_scale.clone().detach())
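The bare `Optional[nn.Parameter]` annotation added above mirrors a common TorchScript pattern: an attribute that may be `None` (here `bias_scale`, which only gets a value when the convolution has a bias) is declared with an explicit `Optional[...]` type, and the forward code checks it against `None` before use. The sketch below is illustrative only and not from the repository; the class name `ScaledBias` and its arguments are invented.

```python
from typing import Optional

import torch
import torch.nn as nn


class ScaledBias(nn.Module):
    """Toy module mirroring the ScaledConv1d bias/bias_scale handling."""

    def __init__(self, num_channels: int, has_bias: bool = True):
        super().__init__()
        # Declared up front, mirroring the commit's "# for torchscript" annotation.
        self.bias: Optional[nn.Parameter]
        self.bias_scale: Optional[nn.Parameter]
        if has_bias:
            self.bias = nn.Parameter(torch.zeros(num_channels))
            self.bias_scale = nn.Parameter(torch.zeros(1))
        else:
            self.register_parameter("bias", None)
            self.register_parameter("bias_scale", None)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Copy the Optional attributes into locals and branch explicitly on None,
        # so TorchScript can refine their types in the else branch.
        bias = self.bias
        bias_scale = self.bias_scale
        if bias is None or bias_scale is None:
            return x
        else:
            return x + bias * bias_scale.exp()


# Both variants can be scripted once the Optional attributes are handled explicitly.
scripted = torch.jit.script(ScaledBias(num_channels=8, has_bias=False))
```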
@@ -262,7 +273,8 @@ class ScaledConv1d(nn.Conv1d):
         bias_scale = self.bias_scale
         if bias is None or bias_scale is None:
             return None
-        return bias * bias_scale.exp()
+        else:
+            return bias * bias_scale.exp()
 
     def forward(self, input: Tensor) -> Tensor:
         F = torch.nn.functional
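The `get_bias` helpers first copy the possibly-`None` attributes into locals and then branch on an explicit `if`/`else`; inside the `else` branch TorchScript refines both values from `Optional[Tensor]` to `Tensor`, so `.exp()` and the multiplication type-check. A minimal standalone sketch of that refinement pattern (the function name is invented, not from the repository):

```python
from typing import Optional

import torch


@torch.jit.script
def scaled_bias(
    bias: Optional[torch.Tensor], bias_scale: Optional[torch.Tensor]
) -> Optional[torch.Tensor]:
    if bias is None or bias_scale is None:
        return None
    else:
        # In this branch TorchScript knows both values are Tensors.
        return bias * bias_scale.exp()


out = scaled_bias(torch.ones(3), torch.zeros(1))  # -> tensor([1., 1., 1.])
```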
@@ -331,7 +343,8 @@ class ScaledConv2d(nn.Conv2d):
         bias_scale = self.bias_scale
         if bias is None or bias_scale is None:
             return None
-        return bias * bias_scale.exp()
+        else:
+            return bias * bias_scale.exp()
 
     def _conv_forward(self, input, weight):
         F = torch.nn.functional
@@ -412,16 +425,16 @@ class ActivationBalancer(torch.nn.Module):
     def forward(self, x: Tensor) -> Tensor:
         if torch.jit.is_scripting():
             return x
-        return ActivationBalancerFunction.apply(
-            x,
-            self.channel_dim,
-            self.min_positive,
-            self.max_positive,
-            self.max_factor,
-            self.min_abs,
-            self.max_abs,
-        )
+        else:
+            return ActivationBalancerFunction.apply(
+                x,
+                self.channel_dim,
+                self.min_positive,
+                self.max_positive,
+                self.max_factor,
+                self.min_abs,
+                self.max_abs,
+            )
 
 
 class DoubleSwishFunction(torch.autograd.Function):
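Both ActivationBalancer and DoubleSwish guard their custom autograd.Function behind `torch.jit.is_scripting()`: a scripted/exported model takes the plain branch, while the custom Function (which only shapes gradients, or provides a more memory-efficient backward) runs in eager mode. Below is a minimal sketch of that guard with an invented `Square` function, not code from the repository:

```python
import torch


class Square(torch.autograd.Function):
    """Toy custom autograd Function standing in for ActivationBalancerFunction."""

    @staticmethod
    def forward(ctx, x: torch.Tensor) -> torch.Tensor:
        ctx.save_for_backward(x)
        return x * x

    @staticmethod
    def backward(ctx, grad: torch.Tensor) -> torch.Tensor:
        (x,) = ctx.saved_tensors
        return 2.0 * x * grad


class SquareModule(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if torch.jit.is_scripting():
            # Scripted path: plain tensor ops only.
            return x * x
        else:
            # Eager path: use the custom autograd Function.
            return Square.apply(x)


x = torch.randn(3, requires_grad=True)
SquareModule()(x).sum().backward()            # eager: goes through Square.apply
scripted = torch.jit.script(SquareModule())   # scripting compiles only the plain branch
```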
@@ -461,7 +474,8 @@ class DoubleSwish(torch.nn.Module):
         """
         if torch.jit.is_scripting():
             return x * torch.sigmoid(x - 1.0)
-        return DoubleSwishFunction.apply(x)
+        else:
+            return DoubleSwishFunction.apply(x)
 
 
 class ScaledEmbedding(nn.Module):