Mirror of https://github.com/k2-fsa/icefall.git, synced 2025-08-09 18:12:19 +00:00
Support torch 1.6.0 (#433)
commit d53f69108f
parent 5379c8e9fa
@@ -52,7 +52,15 @@ class ActivationBalancerFunction(torch.autograd.Function):
         if x.requires_grad:
             if channel_dim < 0:
                 channel_dim += x.ndim
-            sum_dims = [d for d in range(x.ndim) if d != channel_dim]
+
+            # sum_dims = [d for d in range(x.ndim) if d != channel_dim]
+            # The above line is not torch scriptable for torch 1.6.0
+            # torch.jit.frontend.NotSupportedError: comprehension ifs not supported yet:  # noqa
+            sum_dims = []
+            for d in range(x.ndim):
+                if d != channel_dim:
+                    sum_dims.append(d)
+
             xgt0 = x > 0
             proportion_positive = torch.mean(
                 xgt0.to(x.dtype), dim=sum_dims, keepdim=True
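Note on the hunk above: under torch 1.6.0, torch.jit.script rejects list comprehensions that contain an `if` filter with the NotSupportedError quoted in the comment, while an explicit loop with append compiles fine. A minimal standalone sketch of the workaround (the function name dims_except is illustrative, not from the commit):

    from typing import List

    import torch

    @torch.jit.script
    def dims_except(ndim: int, channel_dim: int) -> List[int]:
        # Scriptable replacement for: [d for d in range(ndim) if d != channel_dim]
        dims: List[int] = []
        for d in range(ndim):
            if d != channel_dim:
                dims.append(d)
        return dims

The list built this way is what the surrounding code passes as dim=sum_dims to torch.mean, averaging over every dimension except the channel dimension.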
@@ -214,8 +222,8 @@ class ScaledLinear(nn.Linear):
     def get_bias(self):
         if self.bias is None or self.bias_scale is None:
             return None
-
-        return self.bias * self.bias_scale.exp()
+        else:
+            return self.bias * self.bias_scale.exp()

     def forward(self, input: Tensor) -> Tensor:
         return torch.nn.functional.linear(
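The change from an early return to an explicit if/else plays well with TorchScript's Optional type refinement: the `is None` checks narrow both values to Tensor inside the else branch, a shape of control flow that the torch 1.6.0 frontend handles more reliably than return-then-fall-through. A sketch of the pattern as a free, scriptable function (names are illustrative, not from the commit):

    from typing import Optional

    import torch
    from torch import Tensor

    @torch.jit.script
    def scaled_bias(
        bias: Optional[Tensor], scale: Optional[Tensor]
    ) -> Optional[Tensor]:
        if bias is None or scale is None:
            return None
        else:
            # Both operands are refined from Optional[Tensor] to Tensor here.
            return bias * scale.exp()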
@@ -234,6 +242,9 @@ class ScaledConv1d(nn.Conv1d):
     ):
         super(ScaledConv1d, self).__init__(*args, **kwargs)
         initial_scale = torch.tensor(initial_scale).log()
+
+        self.bias_scale: Optional[nn.Parameter]  # for torchscript
+
         self.weight_scale = nn.Parameter(initial_scale.clone().detach())
         if self.bias is not None:
             self.bias_scale = nn.Parameter(initial_scale.clone().detach())
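The bare annotation added here tells TorchScript that bias_scale may be None on some instances (a ScaledConv1d constructed without a bias), instead of letting the attribute's type be pinned by whatever one scripted instance happens to hold. A minimal sketch of the idea in a hypothetical module, using a class-level annotation that serves the same purpose:

    from typing import Optional

    import torch
    import torch.nn as nn

    class WithOptionalScale(nn.Module):
        bias_scale: Optional[nn.Parameter]  # for torchscript

        def __init__(self, has_bias: bool = True):
            super().__init__()
            if has_bias:
                self.bias_scale = nn.Parameter(torch.zeros(1))
            else:
                self.register_parameter("bias_scale", None)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            bias_scale = self.bias_scale
            if bias_scale is None:
                return x
            else:
                return x * bias_scale.exp()

    scripted = torch.jit.script(WithOptionalScale(has_bias=False))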
@@ -262,7 +273,8 @@ class ScaledConv1d(nn.Conv1d):
         bias_scale = self.bias_scale
         if bias is None or bias_scale is None:
             return None
-        return bias * bias_scale.exp()
+        else:
+            return bias * bias_scale.exp()

     def forward(self, input: Tensor) -> Tensor:
         F = torch.nn.functional
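Note the context line `bias_scale = self.bias_scale` above the check: TorchScript applies Optional refinement to local variables but not to attribute accesses, so testing `self.bias_scale is None` directly would not narrow the type for a later `self.bias_scale.exp()`. Copying the attribute to a local first keeps the refined Tensor type available in the new else branch.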
@@ -331,7 +343,8 @@ class ScaledConv2d(nn.Conv2d):
         bias_scale = self.bias_scale
         if bias is None or bias_scale is None:
             return None
-        return bias * bias_scale.exp()
+        else:
+            return bias * bias_scale.exp()

     def _conv_forward(self, input, weight):
         F = torch.nn.functional
@@ -412,16 +425,16 @@ class ActivationBalancer(torch.nn.Module):
     def forward(self, x: Tensor) -> Tensor:
         if torch.jit.is_scripting():
             return x
-
-        return ActivationBalancerFunction.apply(
-            x,
-            self.channel_dim,
-            self.min_positive,
-            self.max_positive,
-            self.max_factor,
-            self.min_abs,
-            self.max_abs,
-        )
+        else:
+            return ActivationBalancerFunction.apply(
+                x,
+                self.channel_dim,
+                self.min_positive,
+                self.max_positive,
+                self.max_factor,
+                self.min_abs,
+                self.max_abs,
+            )


 class DoubleSwishFunction(torch.autograd.Function):
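This hunk keeps the established export strategy while satisfying the stricter torch 1.6.0 frontend: the guard suggests torch.autograd.Function.apply is not usable from scripted code, so under torch.jit.script the module reduces to the identity and the custom backward only runs in eager mode. A reduced sketch that mirrors the same structure (ScaleGrad and Balancerish are hypothetical stand-ins):

    import torch
    from torch import Tensor

    class ScaleGrad(torch.autograd.Function):
        # Stand-in for ActivationBalancerFunction: identity forward,
        # deliberately modified backward.
        @staticmethod
        def forward(ctx, x: Tensor) -> Tensor:
            return x

        @staticmethod
        def backward(ctx, grad: Tensor) -> Tensor:
            return 0.5 * grad

    class Balancerish(torch.nn.Module):
        def forward(self, x: Tensor) -> Tensor:
            if torch.jit.is_scripting():
                # Scripted path: skip the custom Function entirely.
                return x
            else:
                return ScaleGrad.apply(x)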
@@ -461,7 +474,8 @@ class DoubleSwish(torch.nn.Module):
         """
         if torch.jit.is_scripting():
             return x * torch.sigmoid(x - 1.0)
-        return DoubleSwishFunction.apply(x)
+        else:
+            return DoubleSwishFunction.apply(x)


 class ScaledEmbedding(nn.Module):
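For reference, both branches compute the same activation, DoubleSwish(x) = x * sigmoid(x - 1.0); DoubleSwishFunction just supplies a custom backward for it in eager mode. A quick eager-mode sanity check of the closed form (illustrative, not part of the commit):

    import torch

    def double_swish(x: torch.Tensor) -> torch.Tensor:
        # Same closed form as the scripted branch above.
        return x * torch.sigmoid(x - 1.0)

    x = torch.randn(8, dtype=torch.double, requires_grad=True)
    # gradcheck compares autograd's gradient with finite differences.
    assert torch.autograd.gradcheck(double_swish, (x,))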