diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4b/diagonalize.py b/egs/librispeech/ASR/pruned_transducer_stateless4b/diagonalize.py index 1da066e54..ec2bc9ceb 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless4b/diagonalize.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless4b/diagonalize.py @@ -221,11 +221,12 @@ class OrthogonalTransformation(nn.Module): """ x = torch.matmul(x, self.weight.t()) if self.step % 10 == 0 and self.train(): - # store covariance after input transform. - # Update covariance stats every 10 batches (in training mode) - f = x.reshape(-1, x.shape[-1]) - cov = torch.matmul(f.t(), f) # channel_dim by channel_dim - self.feats_cov.mul_(self.beta).add_(cov, alpha=(1-self.beta)) + with torch.no_grad(): + # store covariance after input transform. + # Update covariance stats every 10 batches (in training mode) + f = x.reshape(-1, x.shape[-1]) + cov = torch.matmul(f.t(), f) # channel_dim by channel_dim + self.feats_cov.mul_(self.beta).add_(cov, alpha=(1-self.beta)) self.step += 1 return x