diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py
index 8dcfdddc7..ca8617e22 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py
@@ -713,25 +713,6 @@ class GaussProjDrop(torch.nn.Module):
             x = (x_next * self.rand_scale + x_bypass)
             return x
 
-class PseudoNormalizeFunction(torch.autograd.Function):
-    """
-    Function object that is the identity function in the forward pass; and, in the
-    backward pass, removes the component of the derivative in the direction of x itself
-    (as if it had gone through some kind of normalization layer
-    """
-    @staticmethod
-    def forward(ctx, x: Tensor) -> Tensor:
-        ctx.save_for_backward(x)
-        return x
-
-    @staticmethod
-    def backward(ctx, x_grad: Tensor) -> Tensor:
-        x, = ctx.saved_tensors
-        eps = 1.0e-20
-        x_sumsq = (x**2).sum() + eps
-        grad_x_sum = (x_grad * x).sum()
-        return x_grad - x * (grad_x_sum / x_sumsq)
-
 
 def _compute_correlation_loss(cov: Tensor,
                               eps: float) -> Tensor:
@@ -759,7 +740,6 @@ def _update_cov_stats(cov: Tensor,
        x: Tensor of features/activations, of shape (num_frames, num_channels)
        beta: The decay constant for the stats, e.g. 0.8.
     """
-    x = PseudoNormalizeFunction.apply(x)
     new_cov = torch.matmul(x.t(), x)
     return cov * beta + new_cov * (1-beta)
 
@@ -818,18 +798,19 @@ class DecorrelateFunction(torch.autograd.Function):
         # the computation.
         x_grad_old_sqnorm = (x_grad ** 2).sum(dim=1)
 
-        x_sqnorm = (x.detach() ** 2).sum(dim=1)
-
-        x_desired_sqscale = x_grad_old_sqnorm ** 0.5 # desired scale of x*x in sum for cov
-        x_desired_sqscale /= (x_desired_sqscale.sum() + 1.0e-20) # sum-to-one scales
-        x_desired_sqscale_is_inf = (x_desired_sqscale - x_desired_sqscale != 0)
-        # if grads are inf, use equal scales for frames (can happen due to GradScaler, in half
-        # precision)
-        x_desired_sqscale.masked_fill_(x_desired_sqscale_is_inf, 1.0 / x_desired_sqscale.numel())
-
-        x_factor = (x_desired_sqscale * num_channels / (x_sqnorm + ctx.eps)) ** 0.5
 
         with torch.enable_grad():
+            x_sqnorm = (x ** 2).sum(dim=1)
+
+            x_desired_sqscale = x_grad_old_sqnorm ** 0.5 # desired scale of x*x in sum for cov
+            x_desired_sqscale /= (x_desired_sqscale.sum() + 1.0e-20) # sum-to-one scales
+            x_desired_sqscale_is_inf = (x_desired_sqscale - x_desired_sqscale != 0)
+            # if grads are inf, use equal scales for frames (can happen due to GradScaler, in half
+            # precision)
+            x_desired_sqscale.masked_fill_(x_desired_sqscale_is_inf, 1.0 / x_desired_sqscale.numel())
+
+            x_factor = (x_desired_sqscale * num_channels / (x_sqnorm + ctx.eps)) ** 0.5
+
             scaled_x = x * x_factor.unsqueeze(-1)
             cov = _update_cov_stats(old_cov, scaled_x, ctx.beta)
             assert old_cov.dtype != torch.float16
@@ -1014,13 +995,6 @@ def _test_gauss_proj_drop():
     m1.eval()
     m2.eval()
 
-def _test_pseudo_normalize():
-    x = torch.randn(3, 4)
-    x.requires_grad = True
-    y = PseudoNormalizeFunction.apply(x)
-    l = (y**2).sum()
-    l.backward()
-    assert (x.grad * x).sum().abs() < 0.1
 
 def _test_decorrelate():
     D = 384
@@ -1049,7 +1023,6 @@ if __name__ == "__main__":
    logging.getLogger().setLevel(logging.INFO)
    torch.set_num_threads(1)
    torch.set_num_interop_threads(1)
-    _test_pseudo_normalize()
    _test_decorrelate()
    _test_gauss_proj_drop()
    _test_activation_balancer_sign()
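
Aside from dropping PseudoNormalizeFunction, the functional change above is that x_sqnorm and x_factor are now computed from x inside the torch.enable_grad() block rather than from x.detach() outside it, so the per-frame scaling itself is differentiated through when the decorrelation loss is backpropagated. The sketch below is not part of the patch: it uses a toy squared loss as a stand-in for DecorrelateFunction's covariance-based loss, and the values num_channels=4 and eps=1.0e-4 are arbitrary, purely to illustrate that the two ways of building the scale yield different gradients with respect to x.

import torch

torch.manual_seed(0)
num_channels = 4   # toy value, not taken from the patch
eps = 1.0e-4       # toy value, not taken from the patch
x = torch.randn(3, num_channels, requires_grad=True)

# Old behaviour: the per-frame scale is built from x.detach() outside any
# enable_grad block, so autograd treats it as a constant.
scale_const = (num_channels / ((x.detach() ** 2).sum(dim=1) + eps)) ** 0.5
loss_const = (x * scale_const.unsqueeze(-1)).pow(2).sum()
(grad_const,) = torch.autograd.grad(loss_const, x)

# New behaviour: the scale is built from x itself inside enable_grad
# (mirroring the patch), so it is differentiated through as well; with this
# toy loss the scaling makes the loss nearly scale-invariant in x.
with torch.enable_grad():
    scale_live = (num_channels / ((x ** 2).sum(dim=1) + eps)) ** 0.5
    loss_live = (x * scale_live.unsqueeze(-1)).pow(2).sum()
(grad_live,) = torch.autograd.grad(loss_live, x)

print(torch.allclose(grad_const, grad_live))  # False: the gradients differ

In the real backward pass this means the gradient of the decorrelation loss also flows through the frame-weighting factors, not only through scaled_x as a linear function of x.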