Mirror of https://github.com/k2-fsa/icefall.git
Fixes for jit scripting and cosmetic improvements
Commit dadeb3feec (parent 60d491eee6)
@@ -442,8 +442,6 @@ class BasicNormFunction(torch.autograd.Function):
     @staticmethod
     def forward(ctx, x: Tensor, eps: Tensor, scale: Tensor, channel_dim: int,
                 store_output_for_backprop: bool) -> Tensor:
-        if channel_dim < 0:
-            channel_dim = channel_dim + x.ndim
         ctx.store_output_for_backprop = store_output_for_backprop
         ctx.channel_dim = channel_dim
         scales = ((torch.mean(x ** 2, dim=channel_dim, keepdim=True) + eps) ** -0.5) * scale
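The two deleted lines normalized a negative channel_dim before use. A plausible reading of why this is safe to drop here (my inference, not stated in the commit message): PyTorch reduction ops resolve negative dims themselves, as this minimal check illustrates.

```python
import torch

# torch.mean resolves a negative dim on its own, so pre-normalizing
# channel_dim is unnecessary for this forward computation:
# dim=-1 and dim=x.ndim - 1 reduce over the same axis.
x = torch.randn(2, 3, 4)
assert torch.equal(
    torch.mean(x ** 2, dim=-1, keepdim=True),
    torch.mean(x ** 2, dim=x.ndim - 1, keepdim=True),
)
```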
@@ -540,12 +538,7 @@ class BasicNorm(torch.nn.Module):
 
         if torch.jit.is_scripting():
-            channel_dim = self.channel_dim
-            if channel_dim < 0:
-                channel_dim += x.ndim
-            bias = self.bias
-            for _ in range(channel_dim + 1, x.ndim):
-                bias = bias.unsqueeze(-1)
-            scales = (((torch.mean(x ** 2, dim=channel_dim, keepdim=True) + self.eps.exp()) ** -0.5) *
+            scales = (((torch.mean(x ** 2, dim=self.channel_dim, keepdim=True) + self.log_eps.exp()) ** -0.5) *
                       self.log_scale.exp())
             return x * scales
 
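For readers skimming the diff, here is a self-contained sketch of what the scripting branch now computes: x is scaled by the inverse RMS over one channel dimension, with a learned eps and output scale kept positive by storing them in log space. The class below is a simplified stand-in, not icefall's actual BasicNorm; its name, its initial log_eps / log_scale values, and the omission of the custom autograd path are illustrative assumptions.

```python
import torch
from torch import nn, Tensor


class TinyBasicNorm(nn.Module):
    """Simplified stand-in for BasicNorm: scales x by the inverse RMS
    over one channel dimension. eps and the output scale are learned
    in log space, which keeps them positive during training."""

    def __init__(self, channel_dim: int = -1):
        super().__init__()
        self.channel_dim = channel_dim
        # Illustrative initial values; the real module's defaults differ.
        self.log_eps = nn.Parameter(torch.tensor(0.25).log())
        self.log_scale = nn.Parameter(torch.zeros(()))

    def forward(self, x: Tensor) -> Tensor:
        # Same expression as the '+' line in the hunk above:
        # (mean(x^2) + eps) ** -0.5, multiplied by the learned scale.
        scales = (
            torch.mean(x ** 2, dim=self.channel_dim, keepdim=True)
            + self.log_eps.exp()
        ) ** -0.5 * self.log_scale.exp()
        return x * scales


# The point of the fix: a module written this way scripts cleanly.
m = torch.jit.script(TinyBasicNorm())
y = m(torch.randn(2, 3, 4))
assert y.shape == (2, 3, 4)
```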