diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py index fe8867291..5866ee517 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py @@ -572,7 +572,7 @@ class Whiten(nn.Module): if hasattr(self, 'min_prob') and random.random() < 0.25: # occasionally switch between min_prob and max_prob, based on whether # we are above or below the threshold. - if _whitening_metric(x, self.num_groups) > self.whitening_limit: + if _whitening_metric(x.to(torch.float32), self.num_groups) > self.whitening_limit: # there would be a change to the grad. self.prob = self.max_prob else: