diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py index 5ee4bab98..8bd50d185 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py @@ -52,7 +52,15 @@ class ActivationBalancerFunction(torch.autograd.Function): if x.requires_grad: if channel_dim < 0: channel_dim += x.ndim - sum_dims = [d for d in range(x.ndim) if d != channel_dim] + + # sum_dims = [d for d in range(x.ndim) if d != channel_dim] + # The above line is not torch scriptable for torch 1.6.0 + # torch.jit.frontend.NotSupportedError: comprehension ifs not supported yet: # noqa + sum_dims = 0 + for d in range(x.ndim): + if d != channel_dim: + sum_dims += d + xgt0 = x > 0 proportion_positive = torch.mean( xgt0.to(x.dtype), dim=sum_dims, keepdim=True @@ -214,8 +222,8 @@ class ScaledLinear(nn.Linear): def get_bias(self): if self.bias is None or self.bias_scale is None: return None - - return self.bias * self.bias_scale.exp() + else: + return self.bias * self.bias_scale.exp() def forward(self, input: Tensor) -> Tensor: return torch.nn.functional.linear( @@ -234,6 +242,9 @@ class ScaledConv1d(nn.Conv1d): ): super(ScaledConv1d, self).__init__(*args, **kwargs) initial_scale = torch.tensor(initial_scale).log() + + self.bias_scale: Optional[nn.Parameter] # for torchscript + self.weight_scale = nn.Parameter(initial_scale.clone().detach()) if self.bias is not None: self.bias_scale = nn.Parameter(initial_scale.clone().detach()) @@ -262,7 +273,8 @@ class ScaledConv1d(nn.Conv1d): bias_scale = self.bias_scale if bias is None or bias_scale is None: return None - return bias * bias_scale.exp() + else: + return bias * bias_scale.exp() def forward(self, input: Tensor) -> Tensor: F = torch.nn.functional @@ -331,7 +343,8 @@ class ScaledConv2d(nn.Conv2d): bias_scale = self.bias_scale if bias is None or bias_scale is None: return None - return bias * bias_scale.exp() + else: + return bias * bias_scale.exp() def _conv_forward(self, input, weight): F = torch.nn.functional @@ -412,16 +425,16 @@ class ActivationBalancer(torch.nn.Module): def forward(self, x: Tensor) -> Tensor: if torch.jit.is_scripting(): return x - - return ActivationBalancerFunction.apply( - x, - self.channel_dim, - self.min_positive, - self.max_positive, - self.max_factor, - self.min_abs, - self.max_abs, - ) + else: + return ActivationBalancerFunction.apply( + x, + self.channel_dim, + self.min_positive, + self.max_positive, + self.max_factor, + self.min_abs, + self.max_abs, + ) class DoubleSwishFunction(torch.autograd.Function): @@ -461,7 +474,8 @@ class DoubleSwish(torch.nn.Module): """ if torch.jit.is_scripting(): return x * torch.sigmoid(x - 1.0) - return DoubleSwishFunction.apply(x) + else: + return DoubleSwishFunction.apply(x) class ScaledEmbedding(nn.Module):