diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py index df69f67d8..6dd21f02d 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py @@ -759,8 +759,8 @@ def _update_cov_stats(cov: Tensor, x: Tensor of features/activations, of shape (num_frames, num_channels) beta: The decay constant for the stats, e.g. 0.8. """ + x = PseudoNormalizeFunction.apply(x) new_cov = torch.matmul(x.t(), x) - new_cov = PseudoNormalizeFunction.apply(new_cov) return cov * beta + new_cov * (1-beta)