Use torch.no_grad() for stats

This commit is contained in:
Daniel Povey 2022-05-17 14:49:00 +08:00
parent 07d3369234
commit 9133b57808

View File

@ -221,11 +221,12 @@ class OrthogonalTransformation(nn.Module):
"""
x = torch.matmul(x, self.weight.t())
if self.step % 10 == 0 and self.train():
# store covariance after input transform.
# Update covariance stats every 10 batches (in training mode)
f = x.reshape(-1, x.shape[-1])
cov = torch.matmul(f.t(), f) # channel_dim by channel_dim
self.feats_cov.mul_(self.beta).add_(cov, alpha=(1-self.beta))
with torch.no_grad():
# store covariance after input transform.
# Update covariance stats every 10 batches (in training mode)
f = x.reshape(-1, x.shape[-1])
cov = torch.matmul(f.t(), f) # channel_dim by channel_dim
self.feats_cov.mul_(self.beta).add_(cov, alpha=(1-self.beta))
self.step += 1
return x