Fix jit export for torch 1.6
This commit is contained in:
parent 413ca2ef77 · commit 88ed814197
@@ -740,12 +740,12 @@ class ConformerEncoder(nn.Module):
         assert not self.training
         assert len(states) == 2
         assert states[0].shape == (
-            len(self.layers),
+            self.num_layers,
             left_context,
             src.size(1),
             src.size(2),
         )
-        assert states[1].size(0) == len(self.layers)
+        assert states[1].size(0) == self.num_layers
 
         output = src
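This hunk swaps `len(self.layers)` for a plain integer attribute. The commit title suggests that calling `len()` on an `nn.ModuleList` inside a scripted method fails under torch 1.6, while reading an `int` cached at construction time is fine. A minimal sketch of the pattern, with a hypothetical `Encoder` module standing in for the real conformer code:

import torch
import torch.nn as nn
from torch import Tensor


class Encoder(nn.Module):
    def __init__(self, num_layers: int = 2) -> None:
        super().__init__()
        self.layers = nn.ModuleList(
            [nn.Linear(8, 8) for _ in range(num_layers)]
        )
        # Cache the layer count as a plain int so scripted methods can
        # use it without calling len() on the ModuleList.
        self.num_layers = num_layers

    def forward(self, x: Tensor, states: Tensor) -> Tensor:
        # One cached state per layer; using len(self.layers) here would
        # break scripting on old torch versions.
        assert states.size(0) == self.num_layers
        for layer in self.layers:  # iterating a ModuleList is scriptable
            x = layer(x)
        return x


enc = torch.jit.script(Encoder())
print(enc(torch.randn(3, 8), torch.zeros(2)))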
@@ -52,7 +52,18 @@ class ActivationBalancerFunction(torch.autograd.Function):
         if x.requires_grad:
             if channel_dim < 0:
                 channel_dim += x.ndim
-            sum_dims = [d for d in range(x.ndim) if d != channel_dim]
+
+            sum_dims = []
+            # If the torch version is less than 1.7.0, an `if` inside a
+            # list comprehension raises torch.jit.frontend.NotSupportedError:
+            # comprehension ifs not supported yet
+            if torch.jit.is_scripting() and torch.__version__ < '1.7.0':
+                for d in range(x.ndim):
+                    if d != channel_dim:
+                        sum_dims.append(d)
+            else:
+                sum_dims = [d for d in range(x.ndim) if d != channel_dim]
+
             xgt0 = x > 0
             proportion_positive = torch.mean(
                 xgt0.to(x.dtype), dim=sum_dims, keepdim=True
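Before torch 1.7 the TorchScript frontend rejected an `if` filter inside a list comprehension, hence the explicit loop on the scripting path above. A standalone sketch of the same workaround (the function name is made up; `torch.jit.annotate` pins the empty list to `List[int]`, where the diff above relies on type inference instead):

from typing import List

import torch


@torch.jit.script
def dims_except(ndim: int, channel_dim: int) -> List[int]:
    # Equivalent to [d for d in range(ndim) if d != channel_dim],
    # written as an explicit loop because old TorchScript frontends
    # do not support comprehension ifs.
    sum_dims = torch.jit.annotate(List[int], [])
    for d in range(ndim):
        if d != channel_dim:
            sum_dims.append(d)
    return sum_dims


print(dims_except(4, 1))  # [0, 2, 3]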
@@ -214,8 +225,8 @@ class ScaledLinear(nn.Linear):
     def get_bias(self):
         if self.bias is None or self.bias_scale is None:
             return None
-
-        return self.bias * self.bias_scale.exp()
+        else:
+            return self.bias * self.bias_scale.exp()
 
     def forward(self, input: Tensor) -> Tensor:
         return torch.nn.functional.linear(
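`self.bias` and `self.bias_scale` are `Optional[Tensor]` here, and TorchScript narrows an Optional to `Tensor` through explicit `is None` branching; the added `else:` presumably helps because torch 1.6's refinement did not reliably survive an early return. A minimal sketch of the if/else refinement pattern (a hypothetical module, not the real ScaledLinear):

from typing import Optional

import torch
from torch import Tensor, nn


class WithOptionalBias(nn.Module):
    # Class-level annotations tell TorchScript these attributes are
    # Optional even when __init__ happens to assign a Tensor.
    bias: Optional[Tensor]
    bias_scale: Optional[Tensor]

    def __init__(self, has_bias: bool = True) -> None:
        super().__init__()
        self.bias = torch.zeros(4) if has_bias else None
        self.bias_scale = torch.zeros(1) if has_bias else None

    @torch.jit.export
    def get_bias(self) -> Optional[Tensor]:
        bias = self.bias
        bias_scale = self.bias_scale
        if bias is None or bias_scale is None:
            return None
        else:
            # Both locals are refined from Optional[Tensor] to Tensor
            # inside this branch.
            return bias * bias_scale.exp()


m = torch.jit.script(WithOptionalBias())
print(m.get_bias())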
@@ -262,7 +273,8 @@ class ScaledConv1d(nn.Conv1d):
         bias_scale = self.bias_scale
         if bias is None or bias_scale is None:
             return None
-        return bias * bias_scale.exp()
+        else:
+            return bias * bias_scale.exp()
 
     def forward(self, input: Tensor) -> Tensor:
         F = torch.nn.functional
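`ScaledConv1d.get_bias` makes the same change, with one detail visible in the context lines: the attributes are first copied into locals (`bias_scale = self.bias_scale`). TorchScript's Optional refinement is defined on local variables rather than attribute accesses, so aliasing before the None check, as the sketch above also does, is the conservative pattern for older releases. Scripting and saving such a module is then just:

import torch

# WithOptionalBias is the hypothetical module from the sketch above.
scripted = torch.jit.script(WithOptionalBias(has_bias=False))
print(scripted.get_bias())  # None: the early branch was taken
torch.jit.save(scripted, "with_optional_bias.pt")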
@@ -412,16 +424,16 @@ class ActivationBalancer(torch.nn.Module):
     def forward(self, x: Tensor) -> Tensor:
         if torch.jit.is_scripting():
             return x
-
-        return ActivationBalancerFunction.apply(
-            x,
-            self.channel_dim,
-            self.min_positive,
-            self.max_positive,
-            self.max_factor,
-            self.min_abs,
-            self.max_abs,
-        )
+        else:
+            return ActivationBalancerFunction.apply(
+                x,
+                self.channel_dim,
+                self.min_positive,
+                self.max_positive,
+                self.max_factor,
+                self.min_abs,
+                self.max_abs,
+            )
 
 
 class DoubleSwishFunction(torch.autograd.Function):
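Here `torch.jit.is_scripting()` gates a custom `autograd.Function`: `.apply` is not scriptable, so the scripted graph must take the identity branch (the balancer is a training-time regularizer anyway, and export happens in eval mode). The explicit `else:` presumably matters because torch 1.6 only skips compiling the unscriptable code when it sits in a real else block rather than as fall-through after the early return. A toy sketch of the gating pattern, with a do-nothing Function standing in for ActivationBalancerFunction:

import torch
from torch import Tensor


class IdentityFn(torch.autograd.Function):
    # Hypothetical stand-in for ActivationBalancerFunction.
    @staticmethod
    def forward(ctx, x: Tensor) -> Tensor:
        return x

    @staticmethod
    def backward(ctx, grad: Tensor) -> Tensor:
        return grad


class Balancer(torch.nn.Module):
    def forward(self, x: Tensor) -> Tensor:
        if torch.jit.is_scripting():
            # The scripted export skips the custom Function entirely.
            return x
        else:
            # Eager mode keeps the custom backward.
            return IdentityFn.apply(x)


scripted = torch.jit.script(Balancer())
print(scripted(torch.randn(2)))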
@@ -461,7 +473,8 @@ class DoubleSwish(torch.nn.Module):
         """
         if torch.jit.is_scripting():
             return x * torch.sigmoid(x - 1.0)
-        return DoubleSwishFunction.apply(x)
+        else:
+            return DoubleSwishFunction.apply(x)
 
 
 class ScaledEmbedding(nn.Module):
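Taken together, these changes let the model be scripted and saved under torch 1.6, which is what "jit export" refers to. A minimal end-to-end check under that assumption, using a stand-in module with the same `is_scripting` branching as DoubleSwish:

import torch
from torch import Tensor, nn


class DoubleSwish(nn.Module):
    def forward(self, x: Tensor) -> Tensor:
        if torch.jit.is_scripting():
            return x * torch.sigmoid(x - 1.0)
        else:
            # The real code routes through DoubleSwishFunction.apply(x)
            # here for a memory-cheaper backward; the values match.
            return x * torch.sigmoid(x - 1.0)


model = nn.Sequential(nn.Linear(8, 8), DoubleSwish()).eval()
scripted = torch.jit.script(model)
scripted.save("model_scripted.pt")

x = torch.randn(2, 8)
loaded = torch.jit.load("model_scripted.pt")
assert torch.allclose(loaded(x), model(x))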