mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-19 05:54:20 +00:00
Fix comments
This commit is contained in:
parent
88400a4435
commit
56afb4aa3c
@ -152,7 +152,7 @@ class BasicNorm(torch.nn.Module):
|
||||
self.register_buffer("eps", torch.tensor(eps).log().detach())
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
if not torch.jit.is_tracing():
|
||||
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
|
||||
assert x.shape[self.channel_dim] == self.num_channels
|
||||
scales = (
|
||||
torch.mean(x ** 2, dim=self.channel_dim, keepdim=True)
|
||||
@ -430,19 +430,18 @@ class ActivationBalancer(torch.nn.Module):
|
||||
# https://github.com/pytorch/pytorch/releases/tag/v1.9.0
|
||||
if torch.jit.is_scripting():
|
||||
return x
|
||||
elif torch.jit.is_tracing():
|
||||
return x
|
||||
else:
|
||||
if torch.jit.is_tracing():
|
||||
return x
|
||||
else:
|
||||
return ActivationBalancerFunction.apply(
|
||||
x,
|
||||
self.channel_dim,
|
||||
self.min_positive,
|
||||
self.max_positive,
|
||||
self.max_factor,
|
||||
self.min_abs,
|
||||
self.max_abs,
|
||||
)
|
||||
return ActivationBalancerFunction.apply(
|
||||
x,
|
||||
self.channel_dim,
|
||||
self.min_positive,
|
||||
self.max_positive,
|
||||
self.max_factor,
|
||||
self.min_abs,
|
||||
self.max_abs,
|
||||
)
|
||||
|
||||
|
||||
class DoubleSwishFunction(torch.autograd.Function):
|
||||
@ -486,11 +485,10 @@ class DoubleSwish(torch.nn.Module):
|
||||
# https://github.com/pytorch/pytorch/releases/tag/v1.9.0
|
||||
if torch.jit.is_scripting():
|
||||
return x * torch.sigmoid(x - 1.0)
|
||||
elif torch.jit.is_tracing():
|
||||
return x * torch.sigmoid(x - 1.0)
|
||||
else:
|
||||
if torch.jit.is_tracing():
|
||||
return x * torch.sigmoid(x - 1.0)
|
||||
else:
|
||||
return DoubleSwishFunction.apply(x)
|
||||
return DoubleSwishFunction.apply(x)
|
||||
|
||||
|
||||
class ScaledEmbedding(nn.Module):
|
||||
|
Loading…
x
Reference in New Issue
Block a user