diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py index 068c4f77e..fa8f629f7 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py @@ -738,6 +738,7 @@ class DecorrelateFunction(torch.autograd.Function): full_shape = x.shape x = x.reshape(-1, num_channels) x = x.detach() + old_cov = old_cov.detach() x.requires_grad = True x_grad = x_grad.reshape(-1, num_channels) @@ -804,8 +805,9 @@ class Decorrelate(torch.nn.Module): x = x.transpose(self.channel_dim, -1) x = x.reshape(-1, x.shape[-1]) - cov = torch.matmul(x.t(), x) / x.shape[0] - self.cov.mul_(self.beta).add_(cov, alpha=(1-self.beta)) + cov = torch.matmul(x.t(), x) + with torch.no_grad(): + self.cov.mul_(self.beta).add_(cov, alpha=(1-self.beta)) self.step += 1 return ans # ans == x.