mirror of https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00

Fixes for jit scripting and cosmetic improvements

This commit is contained in:
parent 60d491eee6
commit dadeb3feec
@@ -442,8 +442,6 @@ class BasicNormFunction(torch.autograd.Function):
     @staticmethod
     def forward(ctx, x: Tensor, eps: Tensor, scale: Tensor, channel_dim: int,
                 store_output_for_backprop: bool) -> Tensor:
-        if channel_dim < 0:
-            channel_dim = channel_dim + x.ndim
         ctx.store_output_for_backprop = store_output_for_backprop
         ctx.channel_dim = channel_dim
         scales = ((torch.mean(x ** 2, dim=channel_dim, keepdim=True) + eps) ** -0.5) * scale
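Note on the hunk above: the explicit handling of negative channel_dim appears to be dropped because PyTorch reductions accept negative dims directly, so forward can pass channel_dim through unchanged. A quick standalone check (plain PyTorch, not icefall code):

import torch

x = torch.randn(2, 3, 4)
# A negative dim and its normalized counterpart reduce identically,
# so there is no need to convert channel_dim before calling torch.mean.
assert torch.equal(torch.mean(x ** 2, dim=-1, keepdim=True),
                   torch.mean(x ** 2, dim=2, keepdim=True))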
@@ -466,7 +464,7 @@ class BasicNormFunction(torch.autograd.Function):
             eps.requires_grad = True
             scale.requires_grad = True
             with torch.enable_grad():
-                # recompute scales from x, epsand log_scale.
+                # recompute scales from x, eps and log_scale.
                 scales = ((torch.mean(x ** 2, dim=ctx.channel_dim, keepdim=True) + eps) ** -0.5) * scale
                 ans = x * scales
                 ans.backward(gradient=ans_grad.to(torch.float32))
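Note on the hunk above: apart from the comment fix, this is the recompute-in-backward pattern, where backward re-runs the forward arithmetic under torch.enable_grad() and calls .backward() on the recomputed output instead of spelling out the gradient by hand. A toy sketch of the same pattern (illustrative only, not the icefall class):

import torch
from torch import Tensor

class Square(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x: Tensor) -> Tensor:
        ctx.save_for_backward(x)
        return x ** 2

    @staticmethod
    def backward(ctx, ans_grad: Tensor):
        (x,) = ctx.saved_tensors
        # Re-create a leaf that requires grad, redo the forward math with
        # autograd enabled, and let .backward() produce the input gradient.
        x = x.detach().requires_grad_(True)
        with torch.enable_grad():
            ans = x ** 2
            ans.backward(gradient=ans_grad)
        return x.grad

x = torch.randn(3, requires_grad=True)
Square.apply(x).sum().backward()
print(torch.allclose(x.grad, 2 * x.detach()))  # True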
@@ -540,12 +538,7 @@ class BasicNorm(torch.nn.Module):
 
         if torch.jit.is_scripting():
             channel_dim = self.channel_dim
-            if channel_dim < 0:
-                channel_dim += x.ndim
-            bias = self.bias
-            for _ in range(channel_dim + 1, x.ndim):
-                bias = bias.unsqueeze(-1)
-            scales = (((torch.mean(x ** 2, dim=channel_dim, keepdim=True) + self.eps.exp()) ** -0.5) *
+            scales = (((torch.mean(x ** 2, dim=self.channel_dim, keepdim=True) + self.log_eps.exp()) ** -0.5) *
                       self.log_scale.exp())
             return x * scales
 
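For context, a minimal scripting-friendly sketch of the branch the last hunk leaves behind. This is not the icefall BasicNorm class; the constructor and its defaults are assumptions, and only the forward formula follows the new side of the diff (eps and scale kept in log space, a single channel_dim attribute):

import torch
from torch import Tensor, nn

class TinyBasicNorm(nn.Module):
    def __init__(self, channel_dim: int = -1, eps: float = 0.25) -> None:
        super().__init__()
        self.channel_dim = channel_dim
        # eps and scale are stored in log space, matching the .exp() calls above
        self.log_eps = nn.Parameter(torch.tensor(eps).log())
        self.log_scale = nn.Parameter(torch.zeros(()))

    def forward(self, x: Tensor) -> Tensor:
        scales = (((torch.mean(x ** 2, dim=self.channel_dim, keepdim=True)
                    + self.log_eps.exp()) ** -0.5) * self.log_scale.exp())
        return x * scales

# The simplified branch has no Python-level shape juggling left, so it
# scripts cleanly with torch.jit.script.
scripted = torch.jit.script(TinyBasicNorm())
print(scripted(torch.randn(2, 3, 4)).shape)  # torch.Size([2, 3, 4])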