mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-19 05:54:20 +00:00
Fix torch.jit.script
This commit is contained in:
parent
f51e64dada
commit
88400a4435
@ -424,18 +424,25 @@ class ActivationBalancer(torch.nn.Module):
|
||||
self.max_abs = max_abs
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
if torch.jit.is_scripting() or torch.jit.is_tracing():
|
||||
# Pytorch issue: https://github.com/pytorch/pytorch/issues/47379
|
||||
# Fixed: https://github.com/pytorch/pytorch/pull/49853
|
||||
# The fix was included in v1.9.0
|
||||
# https://github.com/pytorch/pytorch/releases/tag/v1.9.0
|
||||
if torch.jit.is_scripting():
|
||||
return x
|
||||
else:
|
||||
return ActivationBalancerFunction.apply(
|
||||
x,
|
||||
self.channel_dim,
|
||||
self.min_positive,
|
||||
self.max_positive,
|
||||
self.max_factor,
|
||||
self.min_abs,
|
||||
self.max_abs,
|
||||
)
|
||||
if torch.jit.is_tracing():
|
||||
return x
|
||||
else:
|
||||
return ActivationBalancerFunction.apply(
|
||||
x,
|
||||
self.channel_dim,
|
||||
self.min_positive,
|
||||
self.max_positive,
|
||||
self.max_factor,
|
||||
self.min_abs,
|
||||
self.max_abs,
|
||||
)
|
||||
|
||||
|
||||
class DoubleSwishFunction(torch.autograd.Function):
|
||||
@ -473,10 +480,17 @@ class DoubleSwish(torch.nn.Module):
|
||||
"""Return double-swish activation function which is an approximation to Swish(Swish(x)),
|
||||
that we approximate closely with x * sigmoid(x-1).
|
||||
"""
|
||||
if torch.jit.is_scripting() or torch.jit.is_tracing():
|
||||
# Pytorch issue: https://github.com/pytorch/pytorch/issues/47379
|
||||
# Fixed: https://github.com/pytorch/pytorch/pull/49853
|
||||
# The fix was included in v1.9.0
|
||||
# https://github.com/pytorch/pytorch/releases/tag/v1.9.0
|
||||
if torch.jit.is_scripting():
|
||||
return x * torch.sigmoid(x - 1.0)
|
||||
else:
|
||||
return DoubleSwishFunction.apply(x)
|
||||
if torch.jit.is_tracing():
|
||||
return x * torch.sigmoid(x - 1.0)
|
||||
else:
|
||||
return DoubleSwishFunction.apply(x)
|
||||
|
||||
|
||||
class ScaledEmbedding(nn.Module):
|
||||
|
Loading…
x
Reference in New Issue
Block a user