mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-07 08:04:18 +00:00
Scale covariance
This commit is contained in:
parent
c923b5900e
commit
ceb4eb4b85
@ -225,6 +225,7 @@ class OrthogonalTransformation(nn.Module):
|
|||||||
# store covariance after input transform.
|
# store covariance after input transform.
|
||||||
# Update covariance stats every 10 batches (in training mode)
|
# Update covariance stats every 10 batches (in training mode)
|
||||||
f = x.reshape(-1, x.shape[-1])
|
f = x.reshape(-1, x.shape[-1])
|
||||||
|
f = f * (f.shape[0] ** -0.5)
|
||||||
cov = torch.matmul(f.t(), f) # channel_dim by channel_dim
|
cov = torch.matmul(f.t(), f) # channel_dim by channel_dim
|
||||||
self.feats_cov.mul_(self.beta).add_(cov, alpha=(1-self.beta))
|
self.feats_cov.mul_(self.beta).add_(cov, alpha=(1-self.beta))
|
||||||
self.step += 1
|
self.step += 1
|
||||||
|
Loading…
x
Reference in New Issue
Block a user