Symmetrize covariance

This commit is contained in:
Daniel Povey 2022-05-17 15:14:23 +08:00
parent 9133b57808
commit c923b5900e

View File

@ -259,5 +259,5 @@ class OrthogonalTransformation(nn.Module):
@torch.no_grad()
def get_transformation_out(self) -> Tensor:
# see also get_transformation() above for notes on this.
cov = self.feats_cov
cov = 0.5 * (self.feats_cov + self.feats_cov.t()) # make sure symmetric
return get_transformation(cov)