Fix an error introduced by supporting torchscript for torch 1.6.0 (#434)

This commit is contained in:
Fangjun Kuang 2022-06-18 08:57:20 +08:00 committed by GitHub
parent d53f69108f
commit ab788980c9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -56,10 +56,10 @@ class ActivationBalancerFunction(torch.autograd.Function):
# sum_dims = [d for d in range(x.ndim) if d != channel_dim] # sum_dims = [d for d in range(x.ndim) if d != channel_dim]
# The above line is not torch scriptable for torch 1.6.0 # The above line is not torch scriptable for torch 1.6.0
# torch.jit.frontend.NotSupportedError: comprehension ifs not supported yet: # noqa # torch.jit.frontend.NotSupportedError: comprehension ifs not supported yet: # noqa
sum_dims = 0 sum_dims = []
for d in range(x.ndim): for d in range(x.ndim):
if d != channel_dim: if d != channel_dim:
sum_dims += d sum_dims.append(d)
xgt0 = x > 0 xgt0 = x > 0
proportion_positive = torch.mean( proportion_positive = torch.mean(