Bug fix RE float16

Daniel Povey 2022-10-16 10:58:22 +08:00
parent fc728f2738
commit 1135669e93


@@ -572,7 +572,7 @@ class Whiten(nn.Module):
             if hasattr(self, 'min_prob') and random.random() < 0.25:
                 # occasionally switch between min_prob and max_prob, based on whether
                 # we are above or below the threshold.
-                if _whitening_metric(x, self.num_groups) > self.whitening_limit:
+                if _whitening_metric(x.to(torch.float32), self.num_groups) > self.whitening_limit:
                     # there would be a change to the grad.
                     self.prob = self.max_prob
                 else:
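
The cast matters because float16 cannot represent values above roughly 65504, so the squared activations accumulated inside a covariance-style whitening metric can overflow to inf (or lose precision) under half-precision training. Below is a minimal illustrative sketch, not code from scaling.py: it shows the overflow and the call-site pattern the commit adopts, with metric_fn as a hypothetical stand-in for _whitening_metric.

import torch

# float16 saturates to inf above ~65504, so a squared activation that is
# harmless in float32 can already be inf in half precision.
a = torch.tensor(300.0, dtype=torch.float16)
print(a * a)                      # inf: 300**2 = 90000 exceeds the float16 max
print(a.to(torch.float32) ** 2)   # tensor(90000.), representable in float32

# Call-site pattern matching the fix: cast to float32 before evaluating the
# metric.  metric_fn stands in for _whitening_metric (hypothetical helper).
def metric_exceeds_limit(x, num_groups, limit, metric_fn):
    return metric_fn(x.to(torch.float32), num_groups) > limit

Computing only the metric in float32 keeps the forward pass itself in half precision, so the change costs almost nothing while making the threshold comparison reliable.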