Fixes for jit scripting and cosmetic improvements

Daniel Povey 2023-01-01 14:35:51 +08:00
parent 60d491eee6
commit dadeb3feec


@@ -442,8 +442,6 @@ class BasicNormFunction(torch.autograd.Function):
     @staticmethod
     def forward(ctx, x: Tensor, eps: Tensor, scale: Tensor, channel_dim: int,
                 store_output_for_backprop: bool) -> Tensor:
-        if channel_dim < 0:
-            channel_dim = channel_dim + x.ndim
         ctx.store_output_for_backprop = store_output_for_backprop
         ctx.channel_dim = channel_dim
         scales = ((torch.mean(x ** 2, dim=channel_dim, keepdim=True) + eps) ** -0.5) * scale
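
The fixup removed above is redundant: PyTorch reductions accept negative dims directly, and the negative value stored in ctx.channel_dim works the same way in backward(). A minimal sketch checking that equivalence (illustrative names, not taken from the commit):

import torch

# dim=-1 and dim=x.ndim - 1 select the same axis, so normalizing
# channel_dim before the reduction added nothing.
x = torch.randn(2, 3, 4)
a = torch.mean(x ** 2, dim=-1, keepdim=True)
b = torch.mean(x ** 2, dim=x.ndim - 1, keepdim=True)
assert torch.equal(a, b)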
@@ -466,7 +464,7 @@ class BasicNormFunction(torch.autograd.Function):
         eps.requires_grad = True
         scale.requires_grad = True
         with torch.enable_grad():
-            # recompute scales from x, epsand log_scale.
+            # recompute scales from x, eps and log_scale.
             scales = ((torch.mean(x ** 2, dim=ctx.channel_dim, keepdim=True) + eps) ** -0.5) * scale
             ans = x * scales
             ans.backward(gradient=ans_grad.to(torch.float32))
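
For context, the backward() above follows a recompute-instead-of-store pattern: rather than saving the scales tensor, it rebuilds the forward graph from the saved inputs under torch.enable_grad() and lets autograd produce the input gradients. A self-contained sketch of that pattern (a simplified, hypothetical stand-in, not the actual BasicNormFunction):

import torch
from torch import Tensor

class RecomputeScale(torch.autograd.Function):
    # Hypothetical, minimal version of the memory-saving pattern above.

    @staticmethod
    def forward(ctx, x: Tensor, eps: Tensor) -> Tensor:
        ctx.save_for_backward(x, eps)
        scales = (torch.mean(x ** 2, dim=-1, keepdim=True) + eps) ** -0.5
        return x * scales

    @staticmethod
    def backward(ctx, ans_grad: Tensor):
        x, eps = ctx.saved_tensors
        x = x.detach().requires_grad_(True)
        eps = eps.detach().requires_grad_(True)
        with torch.enable_grad():
            # recompute scales from x and eps instead of having stored them
            scales = (torch.mean(x ** 2, dim=-1, keepdim=True) + eps) ** -0.5
            ans = x * scales
            ans.backward(gradient=ans_grad)
        return x.grad, eps.grad

y = RecomputeScale.apply(torch.randn(2, 3, requires_grad=True), torch.tensor(1e-3))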
@@ -540,12 +538,7 @@ class BasicNorm(torch.nn.Module):
         if torch.jit.is_scripting():
             channel_dim = self.channel_dim
-            if channel_dim < 0:
-                channel_dim += x.ndim
-            bias = self.bias
-            for _ in range(channel_dim + 1, x.ndim):
-                bias = bias.unsqueeze(-1)
-            scales = (((torch.mean(x ** 2, dim=channel_dim, keepdim=True) + self.eps.exp()) ** -0.5) *
+            scales = (((torch.mean(x ** 2, dim=self.channel_dim, keepdim=True) + self.log_eps.exp()) ** -0.5) *
                       self.log_scale.exp())
             return x * scales
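
Taken together, the changes keep the is_scripting() branch straight-line, which is what TorchScript compiles. A sketch of scripting a reduced module built on the same expression (TinyNorm is an illustrative stand-in, not icefall code):

import torch
from torch import Tensor

class TinyNorm(torch.nn.Module):
    # Illustrative stand-in exposing only the jit-visible computation.
    def __init__(self, channel_dim: int = -1):
        super().__init__()
        self.channel_dim = channel_dim
        self.log_eps = torch.nn.Parameter(torch.tensor(-5.0))
        self.log_scale = torch.nn.Parameter(torch.tensor(0.0))

    def forward(self, x: Tensor) -> Tensor:
        # Same straight-line expression as the new is_scripting() branch.
        scales = (((torch.mean(x ** 2, dim=self.channel_dim, keepdim=True)
                    + self.log_eps.exp()) ** -0.5) * self.log_scale.exp())
        return x * scales

scripted = torch.jit.script(TinyNorm())
print(scripted(torch.randn(2, 3, 4)).shape)  # torch.Size([2, 3, 4])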