Move PseudoNormalizeFunction to a different place.

This commit is contained in:
Daniel Povey 2022-06-10 14:01:13 +08:00
parent 77357cdaa8
commit 58cbc3d961

View File

@ -759,8 +759,8 @@ def _update_cov_stats(cov: Tensor,
x: Tensor of features/activations, of shape (num_frames, num_channels)
beta: The decay constant for the stats, e.g. 0.8.
"""
x = PseudoNormalizeFunction.apply(x)
new_cov = torch.matmul(x.t(), x)
new_cov = PseudoNormalizeFunction.apply(new_cov)
return cov * beta + new_cov * (1-beta)