Use torch.no_grad() for stats

Daniel Povey 2022-05-17 14:49:00 +08:00
parent 07d3369234
commit 9133b57808


@@ -221,11 +221,12 @@ class OrthogonalTransformation(nn.Module):
         """
         x = torch.matmul(x, self.weight.t())
         if self.step % 10 == 0 and self.train():
-            # store covariance after input transform.
-            # Update covariance stats every 10 batches (in training mode)
-            f = x.reshape(-1, x.shape[-1])
-            cov = torch.matmul(f.t(), f)  # channel_dim by channel_dim
-            self.feats_cov.mul_(self.beta).add_(cov, alpha=(1-self.beta))
+            with torch.no_grad():
+                # store covariance after input transform.
+                # Update covariance stats every 10 batches (in training mode)
+                f = x.reshape(-1, x.shape[-1])
+                cov = torch.matmul(f.t(), f)  # channel_dim by channel_dim
+                self.feats_cov.mul_(self.beta).add_(cov, alpha=(1-self.beta))
         self.step += 1
         return x
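
The point of the change: without torch.no_grad(), the covariance matmul is recorded in the autograd graph, and the in-place add_ of a grad-tracking tensor into the feats_cov buffer would pull the buffer itself into the graph, retaining activations and graph state across batches for what is pure bookkeeping. Wrapping the update in torch.no_grad() keeps the statistics out of autograd entirely. Below is a minimal self-contained sketch of the pattern, assuming the surrounding class roughly matches the diff context; the constructor, the num_channels argument, and the beta default are illustrative, not taken from the original file. Note also that the diff's condition calls self.train(), which is nn.Module.train(): it switches the module into training mode and returns self (always truthy); the sketch uses self.training, the attribute that actually reflects the current mode.

import torch
import torch.nn as nn


class OrthogonalTransformation(nn.Module):
    # Sketch only: __init__ signature and defaults are assumptions,
    # not copied from the original file.
    def __init__(self, num_channels: int, beta: float = 0.9):
        super().__init__()
        # Learned transform; the real class presumably initializes or
        # constrains this to stay orthogonal, which is omitted here.
        self.weight = nn.Parameter(torch.eye(num_channels))
        self.beta = beta
        self.step = 0
        # Running covariance of transformed features: a buffer, not a
        # parameter, so the optimizer never touches it.
        self.register_buffer('feats_cov',
                             torch.zeros(num_channels, num_channels))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = torch.matmul(x, self.weight.t())
        # Update covariance stats every 10 batches, in training mode only.
        # (The diff tests `self.train()`; `self.training` is used here.)
        if self.step % 10 == 0 and self.training:
            # no_grad keeps this bookkeeping out of the autograd graph:
            # the matmul records no graph nodes, and the in-place EMA
            # update cannot make feats_cov require grad.
            with torch.no_grad():
                f = x.reshape(-1, x.shape[-1])
                cov = torch.matmul(f.t(), f)  # channel_dim by channel_dim
                # Exponential moving average of the covariance.
                self.feats_cov.mul_(self.beta).add_(cov,
                                                    alpha=(1 - self.beta))
        self.step += 1
        return x


if __name__ == '__main__':
    m = OrthogonalTransformation(num_channels=64)
    y = m(torch.randn(8, 100, 64))  # step 0, so feats_cov gets updated
    print(y.shape, m.feats_cov.norm())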