Scale covariance

This commit is contained in:
Daniel Povey 2022-05-17 15:32:06 +08:00
parent c923b5900e
commit ceb4eb4b85

View File

@ -225,6 +225,7 @@ class OrthogonalTransformation(nn.Module):
# store covariance after input transform.
# Update covariance stats every 10 batches (in training mode)
f = x.reshape(-1, x.shape[-1])
f = f * (f.shape[0] ** -0.5)
cov = torch.matmul(f.t(), f) # channel_dim by channel_dim
self.feats_cov.mul_(self.beta).add_(cov, alpha=(1-self.beta))
self.step += 1