Fix jit export for torch 1.6

This commit is contained in:
pkufool 2022-06-18 08:35:25 +08:00
parent 413ca2ef77
commit 88ed814197
2 changed files with 30 additions and 17 deletions

View File

@ -740,12 +740,12 @@ class ConformerEncoder(nn.Module):
assert not self.training
assert len(states) == 2
assert states[0].shape == (
len(self.layers),
self.num_layers,
left_context,
src.size(1),
src.size(2),
)
assert states[1].size(0) == len(self.layers)
assert states[1].size(0) == self.num_layers
output = src

View File

@ -52,7 +52,18 @@ class ActivationBalancerFunction(torch.autograd.Function):
if x.requires_grad:
if channel_dim < 0:
channel_dim += x.ndim
sum_dims = [d for d in range(x.ndim) if d != channel_dim]
sum_dims = []
# If torch version less than 1.7.0, `if` in List will cause
# torch.jit.frontend.NotSupportedError: comprehension ifs not
# supported yet
if torch.jit.is_scripting() and torch.__version__ < '1.7.0':
for d in range(x.ndim):
if d != channel_dim:
sum_dims.append(d)
else:
sum_dims = [d for d in range(x.ndim) if d != channel_dim]
xgt0 = x > 0
proportion_positive = torch.mean(
xgt0.to(x.dtype), dim=sum_dims, keepdim=True
@ -214,8 +225,8 @@ class ScaledLinear(nn.Linear):
def get_bias(self):
if self.bias is None or self.bias_scale is None:
return None
return self.bias * self.bias_scale.exp()
else:
return self.bias * self.bias_scale.exp()
def forward(self, input: Tensor) -> Tensor:
return torch.nn.functional.linear(
@ -262,7 +273,8 @@ class ScaledConv1d(nn.Conv1d):
bias_scale = self.bias_scale
if bias is None or bias_scale is None:
return None
return bias * bias_scale.exp()
else:
return bias * bias_scale.exp()
def forward(self, input: Tensor) -> Tensor:
F = torch.nn.functional
@ -412,16 +424,16 @@ class ActivationBalancer(torch.nn.Module):
def forward(self, x: Tensor) -> Tensor:
if torch.jit.is_scripting():
return x
return ActivationBalancerFunction.apply(
x,
self.channel_dim,
self.min_positive,
self.max_positive,
self.max_factor,
self.min_abs,
self.max_abs,
)
else:
return ActivationBalancerFunction.apply(
x,
self.channel_dim,
self.min_positive,
self.max_positive,
self.max_factor,
self.min_abs,
self.max_abs,
)
class DoubleSwishFunction(torch.autograd.Function):
@ -461,7 +473,8 @@ class DoubleSwish(torch.nn.Module):
"""
if torch.jit.is_scripting():
return x * torch.sigmoid(x - 1.0)
return DoubleSwishFunction.apply(x)
else:
return DoubleSwishFunction.apply(x)
class ScaledEmbedding(nn.Module):