From eeb95ed50227c260e1cb1ee446f4a82aaa701f92 Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Fri, 10 Jun 2022 16:25:45 +0800
Subject: [PATCH] Fix issue with cov scale

---
 .../ASR/pruned_transducer_stateless2/scaling.py | 9 ++++++---
 1 file changed, 6 insertions(+), 3 deletions(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py
index 8f1c6f2e7..05eef1554 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py
@@ -819,14 +819,15 @@ class DecorrelateFunction(torch.autograd.Function):
 
         x_grad_old_sqnorm = (x_grad ** 2).sum(dim=1)
         x_sqnorm = (x.detach() ** 2).sum(dim=1)
-        x_desired_sqscale = x_grad_old_sqnorm ** 0.5  # desired scale of x in sum for cov
+
+        x_desired_sqscale = x_grad_old_sqnorm ** 0.5  # desired scale of x*x in sum for cov
         x_desired_sqscale /= (x_desired_sqscale.sum() + 1.0e-20)  # sum-to-one scales
         x_desired_sqscale_is_inf = (x_desired_sqscale - x_desired_sqscale != 0)
         # if grads are inf, use equal scales for frames (can happen due to GradScaler, in half
         # precision)
         x_desired_sqscale.masked_fill_(x_desired_sqscale_is_inf,
                                        1.0 / x_desired_sqscale.numel())
-        x_factor = (x_desired_sqscale / (x_sqnorm + ctx.eps)) ** 0.5
+        x_factor = (x_desired_sqscale * num_channels / (x_sqnorm + ctx.eps)) ** 0.5
 
         with torch.enable_grad():
             scaled_x = x * x_factor.unsqueeze(-1)
@@ -837,6 +838,8 @@ class DecorrelateFunction(torch.autograd.Function):
             # is not differentiable..
             loss = _compute_correlation_loss(cov, ctx.eps)
 
+            #print(f"x_sqnorm mean = {x_sqnorm.mean().item()}, x_sqnorm_mean={x_sqnorm.mean().item()}, x_desired_sqscale_sum={x_desired_sqscale.sum()}, x_grad_old_sqnorm mean = {x_grad_old_sqnorm.mean().item()}, x**2_mean = {(x**2).mean().item()}, scaled_x**2_mean = {(scaled_x**2).mean().item()}, (cov-abs-mean)={cov.abs().mean().item()}, old_cov_abs_mean={old_cov.abs().mean().item()}, loss = {loss}")
+
             if random.random() < 0.01:
                 logging.info(f"Decorrelate: loss = {loss}")
 
@@ -1025,7 +1028,7 @@ def _test_pseudo_normalize():
 
     x = torch.randn(3, 4)
     x.requires_grad = True
     y = PseudoNormalizeFunction.apply(x)
-    l = y.sin().sum()
+    l = (y**2).sum()
    l.backward()
     assert (x.grad * x).sum().abs() < 0.1
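
Why the num_channels factor fixes the covariance scale: x_desired_sqscale
sums to one over frames, so under the old x_factor the trace of
cov = scaled_x^T scaled_x summed to ~1 and each of the num_channels diagonal
entries averaged only ~1/num_channels; multiplying by num_channels before the
square root makes the diagonal average ~1, presumably the scale the
correlation loss and ctx.eps were designed around. A minimal standalone
sketch of that arithmetic (w stands in for x_desired_sqscale; this is an
illustration, not the repo's code):

    import torch

    num_frames, num_channels = 100, 8
    eps = 1.0e-05
    x = torch.randn(num_frames, num_channels)

    x_sqnorm = (x ** 2).sum(dim=1)
    w = torch.rand(num_frames)   # stand-in for x_desired_sqscale
    w = w / w.sum()              # sum-to-one scales, as in the patch

    for name, factor in [("old", (w / (x_sqnorm + eps)) ** 0.5),
                         ("new", (w * num_channels / (x_sqnorm + eps)) ** 0.5)]:
        scaled_x = x * factor.unsqueeze(-1)
        cov = torch.matmul(scaled_x.t(), scaled_x)
        # old: diagonal averages ~1/num_channels; new: ~1.0
        print(name, cov.diag().mean().item())

With the new factor, row i of scaled_x has squared norm w_i * num_channels,
so the trace of cov is num_channels and the per-channel diagonal average is
~1 regardless of the channel count.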
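On the test change in the last hunk: the assert
(x.grad * x).sum().abs() < 0.1 checks that the gradient passed back has
(almost) no component parallel to x. With l = (y**2).sum() the upstream
gradient 2*y points along x (assuming an identity-like forward), so the test
targets exactly the component the backward should remove, whereas
y.sin().sum() produced a gradient whose parallel part was only incidental.
A toy sketch of that property follows; ToyPseudoNormalize is an assumed
stand-in for illustration, not the repo's PseudoNormalizeFunction:

    import torch

    class ToyPseudoNormalize(torch.autograd.Function):
        # Assumed behavior, for illustration only: identity-like forward,
        # backward projects out the gradient component parallel to x --
        # the property the patched test asserts.
        @staticmethod
        def forward(ctx, x):
            ctx.save_for_backward(x)
            return x.view_as(x)  # new tensor so the custom backward runs

        @staticmethod
        def backward(ctx, y_grad):
            (x,) = ctx.saved_tensors
            coeff = (y_grad * x).sum() / ((x * x).sum() + 1.0e-20)
            return y_grad - coeff * x

    x = torch.randn(3, 4)
    x.requires_grad = True
    y = ToyPseudoNormalize.apply(x)
    l = (y ** 2).sum()   # upstream grad 2*y is entirely parallel to x
    l.backward()
    assert (x.grad * x).sum().abs() < 0.1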