Use torch.no_grad() for stats

This commit is contained in:
Daniel Povey 2022-05-17 14:49:00 +08:00
parent 07d3369234
commit 9133b57808

View File

@@ -221,6 +221,7 @@ class OrthogonalTransformation(nn.Module):
         """
         x = torch.matmul(x, self.weight.t())
         if self.step % 10 == 0 and self.train():
+            with torch.no_grad():
                 # store covariance after input transform.
                 # Update covariance stats every 10 batches (in training mode)
                 f = x.reshape(-1, x.shape[-1])