From dadeb3feecae73d17491b295c1483d996f1229eb Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Sun, 1 Jan 2023 14:35:51 +0800
Subject: [PATCH] Fixes for jit scripting and cosmetic improvements

---
 .../ASR/pruned_transducer_stateless7/scaling.py | 11 ++---------
 1 file changed, 2 insertions(+), 9 deletions(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
index 830fe497b..7ac7b0a00 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
@@ -442,8 +442,6 @@ class BasicNormFunction(torch.autograd.Function):
     @staticmethod
     def forward(ctx, x: Tensor, eps: Tensor, scale: Tensor, channel_dim: int,
                 store_output_for_backprop: bool) -> Tensor:
-        if channel_dim < 0:
-            channel_dim = channel_dim + x.ndim
         ctx.store_output_for_backprop = store_output_for_backprop
         ctx.channel_dim = channel_dim
         scales = ((torch.mean(x ** 2, dim=channel_dim, keepdim=True) + eps) ** -0.5) * scale
@@ -466,7 +464,7 @@ class BasicNormFunction(torch.autograd.Function):
         eps.requires_grad = True
         scale.requires_grad = True
         with torch.enable_grad():
-            # recompute scales from x, epsand log_scale.
+            # recompute scales from x, eps and log_scale.
             scales = ((torch.mean(x ** 2, dim=ctx.channel_dim, keepdim=True) + eps) ** -0.5) * scale
             ans = x * scales
             ans.backward(gradient=ans_grad.to(torch.float32))
@@ -540,12 +538,7 @@ class BasicNorm(torch.nn.Module):
 
         if torch.jit.is_scripting():
             channel_dim = self.channel_dim
-            if channel_dim < 0:
-                channel_dim += x.ndim
-            bias = self.bias
-            for _ in range(channel_dim + 1, x.ndim):
-                bias = bias.unsqueeze(-1)
-            scales = (((torch.mean(x ** 2, dim=channel_dim, keepdim=True) + self.eps.exp()) ** -0.5) *
+            scales = (((torch.mean(x ** 2, dim=self.channel_dim, keepdim=True) + self.eps.exp()) ** -0.5) *
                       self.log_scale.exp())
             return x * scales
 
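
Note: the sketch below is not part of the patch and is a hypothetical, stripped-down stand-in for the BasicNorm in scaling.py, shown only to illustrate the point of the change: the simplified branch computes scales in a single tensor expression, with no Python-level bias handling, and torch.mean with keepdim=True handles a negative channel_dim natively, so the module scripts cleanly with torch.jit.script. The class name TinyBasicNorm and its defaults are assumptions, not icefall code.

    import torch
    from torch import Tensor, nn


    class TinyBasicNorm(nn.Module):
        # Hypothetical stand-in for BasicNorm: normalize by the RMS over
        # channel_dim, then apply a single learned scale.
        def __init__(self, channel_dim: int = -1, eps: float = 0.25):
            super().__init__()
            self.channel_dim = channel_dim
            # eps and scale are stored in log space, as in scaling.py.
            self.eps = nn.Parameter(torch.tensor(eps).log().detach())
            self.log_scale = nn.Parameter(torch.zeros(()))

        def forward(self, x: Tensor) -> Tensor:
            # Mirrors the patched jit-scripting branch: one tensor
            # expression; a negative channel_dim needs no fixup because
            # torch.mean accepts negative dims and keepdim=True keeps the
            # result broadcastable against x.
            scales = ((torch.mean(x ** 2, dim=self.channel_dim, keepdim=True)
                       + self.eps.exp()) ** -0.5) * self.log_scale.exp()
            return x * scales


    # Usage: the simplified forward is TorchScript-compatible.
    scripted = torch.jit.script(TinyBasicNorm())
    print(scripted(torch.randn(2, 100, 80)).shape)  # torch.Size([2, 100, 80])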