diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py
index 6dd21f02d..7550c1b0e 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py
@@ -803,7 +803,6 @@ class DecorrelateFunction(torch.autograd.Function):
     def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None, None, None]:
         x, old_cov = ctx.saved_tensors
 
-        # Reshape x and x_grad to be (num_frames, num_channels)
         x = x.transpose(-1, ctx.channel_dim)
         x_grad = x_grad.transpose(-1, ctx.channel_dim)
 
@@ -813,8 +812,22 @@ class DecorrelateFunction(torch.autograd.Function):
         x_grad = x_grad.reshape(-1, num_channels)
         x.requires_grad = True
 
+        # Now, normalize the contributions of frames/pixels x to the covariance,
+        # to have magnitudes proportional to the norm of the gradient on that
+        # frame; the goal is to exclude "don't-care" frames such as padding frames from
+        # the computation.
+
+        x_grad_old_sqnorm = (x_grad ** 2).sum(dim=1)
+        x_sqnorm = (x.detach() ** 2).sum(dim=1)
+        x_desired_sqscale = x_grad_old_sqnorm ** 0.5  # desired scale of x in sum for cov
+        x_factor = (x_desired_sqscale / (x_sqnorm + ctx.eps)) ** 0.5
+
         with torch.enable_grad():
-            cov = _update_cov_stats(old_cov, x, ctx.beta)
+            scaled_x = x * x_factor.unsqueeze(-1)
+            cov = _update_cov_stats(old_cov, scaled_x, ctx.beta)
+            old_cov[:] = cov  # update the stats outside!  This is not really
+                              # how backprop is supposed to work, but this input
+                              # is not differentiable..
             loss = _compute_correlation_loss(cov, ctx.eps)
 
         if random.random() < 0.01:
@@ -823,18 +836,14 @@ class DecorrelateFunction(torch.autograd.Function):
 
         decorr_x_grad = x.grad
 
-        # Now, normalize the magnitudes of the rows of the new grad
-        # contribution, to have magnitudes equals to ctx.scale times
-        # `loss ** 0.5` times the magnitude of the original grad.
-        decorr_x_grad_sqnorm = (decorr_x_grad ** 2).sum(dim=1)
-        x_grad_old_sqnorm = (x_grad ** 2).sum(dim=1)
-
         # loss.detach().clamp(min=0.0, max=1.0) is a factor that means once
         # the loss starts getting quite small (less than 1), we start using
         # smaller derivatives.
+
+
         decorr_loss_scale = ctx.scale * loss.detach().clamp(min=0.0, max=1.0)
-        scale = decorr_loss_scale * (x_grad_old_sqnorm / (decorr_x_grad_sqnorm + 1.0e-20)) ** 0.5
-        decorr_x_grad = decorr_x_grad * scale.unsqueeze(-1)
+        scale = decorr_loss_scale * ((x_grad ** 2).mean() / ((decorr_x_grad ** 2).mean() + 1.0e-20)) ** 0.5
+        decorr_x_grad = decorr_x_grad * scale
 
         x_grad = x_grad + decorr_x_grad
 
@@ -902,19 +911,12 @@ class Decorrelate(torch.nn.Module):
             return x
         with torch.cuda.amp.autocast(enabled=False):
             x = x.to(torch.float32)
-            ans = DecorrelateFunction.apply(x, self.cov.clone(),
+            # the function updates self.cov in its backward pass (it needs the gradient
+            # norm, for frame weighting).
+            ans = DecorrelateFunction.apply(x, self.cov,
                                             self.scale, self.eps, self.beta,
                                             self.channel_dim)  # == x.
-
-            x = x.transpose(self.channel_dim, -1)
-            x = x.reshape(-1, x.shape[-1])
-            cov = torch.matmul(x.t(), x)
-            with torch.no_grad():
-                self.cov.mul_(self.beta).add_(cov, alpha=(1-self.beta))
-            m = self.cov.max()
-            assert m == m
-
-            return ans  # ans == x.
+            return ans
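
For reference, below is a minimal, self-contained sketch of the mechanism this patch introduces: each frame's contribution to the covariance stats is weighted by the norm of its incoming gradient (so padding frames, which carry no gradient, drop out of the stats), and the decorrelation gradient is blended back into the original gradient with a single global scale rather than a per-frame one. This is not the icefall implementation: the bodies of _update_cov_stats and _compute_correlation_loss are not shown in this diff, so update_cov_stats and correlation_loss below are simplified stand-ins, and the shapes and hyperparameters (scale, eps, beta) are illustrative assumptions only.

import torch


def update_cov_stats(old_cov: torch.Tensor, x: torch.Tensor, beta: float) -> torch.Tensor:
    # Decaying average of (num_channels, num_channels) covariance stats.
    # Stand-in for _update_cov_stats, whose body is not part of this diff.
    return beta * old_cov + (1 - beta) * torch.matmul(x.t(), x)


def correlation_loss(cov: torch.Tensor, eps: float) -> torch.Tensor:
    # Off-diagonal energy of the normalized covariance: 0 when channels are
    # fully decorrelated.  Stand-in for _compute_correlation_loss.
    inv_scale = (cov.diag() + eps) ** -0.5
    norm_cov = cov * inv_scale.unsqueeze(0) * inv_scale.unsqueeze(1)
    num_channels = cov.shape[0]
    off_diag = norm_cov - torch.eye(num_channels)
    return (off_diag ** 2).sum() / (num_channels * (num_channels - 1))


def decorrelate_backward_sketch(x, x_grad, old_cov, scale=0.1, eps=1.0e-05, beta=0.95):
    # x, x_grad: (num_frames, num_channels), already reshaped as in the diff.
    # Weight each frame's contribution to the covariance by the norm of its
    # incoming gradient, so zero-gradient (padding) frames contribute nothing.
    x_grad_old_sqnorm = (x_grad ** 2).sum(dim=1)
    x_sqnorm = (x.detach() ** 2).sum(dim=1)
    x_desired_sqscale = x_grad_old_sqnorm ** 0.5
    x_factor = (x_desired_sqscale / (x_sqnorm + eps)) ** 0.5

    x = x.detach().requires_grad_(True)
    with torch.enable_grad():
        scaled_x = x * x_factor.unsqueeze(-1)
        cov = update_cov_stats(old_cov, scaled_x, beta)
        old_cov[:] = cov.detach()  # running stats updated as a side effect
        loss = correlation_loss(cov, eps)
    (decorr_x_grad,) = torch.autograd.grad(loss, x)

    # Blend the decorrelation gradient into the original one with a global
    # (not per-frame) scale, fading it out once the loss drops below 1.
    decorr_loss_scale = scale * loss.detach().clamp(min=0.0, max=1.0)
    g_scale = decorr_loss_scale * (
        (x_grad ** 2).mean() / ((decorr_x_grad ** 2).mean() + 1.0e-20)
    ) ** 0.5
    return x_grad + decorr_x_grad * g_scale


if __name__ == "__main__":
    torch.manual_seed(0)
    num_frames, num_channels = 100, 8
    x = torch.randn(num_frames, num_channels)
    x_grad = torch.randn(num_frames, num_channels)
    x_grad[50:] = 0.0  # "padding" frames carry no gradient
    old_cov = torch.zeros(num_channels, num_channels)
    new_grad = decorrelate_backward_sketch(x, x_grad, old_cov)
    # Padding frames receive no decorrelation gradient and the stats moved.
    print(new_grad.shape, bool((new_grad[50:] == 0).all()), float(old_cov.abs().sum()))

The sketch mirrors the patch's switch from per-row gradient rescaling to a single mean-based ratio, so low-gradient frames are no longer individually amplified when the decorrelation term is mixed in.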