From 9133b57808edde787bf6fec9a05fcca0389011a2 Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Tue, 17 May 2022 14:49:00 +0800
Subject: [PATCH] Use torch.no_grad() for stats

---
 .../ASR/pruned_transducer_stateless4b/diagonalize.py | 11 ++++++-----
 1 file changed, 6 insertions(+), 5 deletions(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4b/diagonalize.py b/egs/librispeech/ASR/pruned_transducer_stateless4b/diagonalize.py
index 1da066e54..ec2bc9ceb 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless4b/diagonalize.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless4b/diagonalize.py
@@ -221,11 +221,12 @@ class OrthogonalTransformation(nn.Module):
         """
         x = torch.matmul(x, self.weight.t())
         if self.step % 10 == 0 and self.train():
-            # store covariance after input transform.
-            # Update covariance stats every 10 batches (in training mode)
-            f = x.reshape(-1, x.shape[-1])
-            cov = torch.matmul(f.t(), f)  # channel_dim by channel_dim
-            self.feats_cov.mul_(self.beta).add_(cov, alpha=(1-self.beta))
+            with torch.no_grad():
+                # store covariance after input transform.
+                # Update covariance stats every 10 batches (in training mode)
+                f = x.reshape(-1, x.shape[-1])
+                cov = torch.matmul(f.t(), f)  # channel_dim by channel_dim
+                self.feats_cov.mul_(self.beta).add_(cov, alpha=(1-self.beta))
         self.step += 1
         return x
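
The patch wraps the running covariance update in torch.no_grad() so the
exponentially-averaged statistics in self.feats_cov are never attached to the
autograd graph; without it, cov is derived from a grad-requiring x, so the
in-place add_ would give the buffer a grad_fn and keep each batch's graph
alive across steps. Below is a minimal, self-contained sketch of that
pattern. The RunningCov class, its channel_dim/beta arguments, and the driver
lines at the bottom are illustrative inventions, not the icefall code; only
the buffer update mirrors the patched lines. Note the sketch also checks
self.training (the boolean mode flag) where the patched file calls
self.train(), which switches the module into training mode and returns self,
so it always evaluates truthy.

    import torch
    import torch.nn as nn
    from torch import Tensor


    class RunningCov(nn.Module):
        """Hypothetical module illustrating no_grad stats accumulation."""

        def __init__(self, channel_dim: int, beta: float = 0.9):
            super().__init__()
            self.beta = beta
            self.step = 0
            # A buffer, not a Parameter: the stats are state, not weights.
            self.register_buffer(
                "feats_cov", torch.zeros(channel_dim, channel_dim)
            )

        def forward(self, x: Tensor) -> Tensor:
            # Update covariance stats every 10 batches, in training mode only.
            # self.training is the mode flag; calling self.train() here would
            # set the mode and always be truthy.
            if self.step % 10 == 0 and self.training:
                with torch.no_grad():
                    # Inside no_grad, cov has no grad_fn, so the in-place
                    # update leaves feats_cov detached from the graph.
                    f = x.reshape(-1, x.shape[-1])
                    cov = torch.matmul(f.t(), f)  # channel_dim x channel_dim
                    self.feats_cov.mul_(self.beta).add_(
                        cov, alpha=(1 - self.beta)
                    )
            self.step += 1
            return x


    m = RunningCov(channel_dim=4)
    y = m(torch.randn(8, 4, requires_grad=True))
    y.sum().backward()  # gradients flow through x; feats_cov stays graph-free
    assert m.feats_cov.grad_fn is None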