From 8e56445c70b2da06ddae2afe4777b352e7e79120 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 8 Jun 2022 20:07:35 +0800 Subject: [PATCH] Try to resolve graph-freed problem --- egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py index 068c4f77e..fa8f629f7 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py @@ -738,6 +738,7 @@ class DecorrelateFunction(torch.autograd.Function): full_shape = x.shape x = x.reshape(-1, num_channels) x = x.detach() + old_cov = old_cov.detach() x.requires_grad = True x_grad = x_grad.reshape(-1, num_channels) @@ -804,8 +805,9 @@ class Decorrelate(torch.nn.Module): x = x.transpose(self.channel_dim, -1) x = x.reshape(-1, x.shape[-1]) - cov = torch.matmul(x.t(), x) / x.shape[0] - self.cov.mul_(self.beta).add_(cov, alpha=(1-self.beta)) + cov = torch.matmul(x.t(), x) + with torch.no_grad(): + self.cov.mul_(self.beta).add_(cov, alpha=(1-self.beta)) self.step += 1 return ans # ans == x.