Bug fix RE float16
This commit is contained in:
parent
fc728f2738
commit
1135669e93
@ -572,7 +572,7 @@ class Whiten(nn.Module):
|
||||
if hasattr(self, 'min_prob') and random.random() < 0.25:
|
||||
# occasionally switch between min_prob and max_prob, based on whether
|
||||
# we are above or below the threshold.
|
||||
if _whitening_metric(x, self.num_groups) > self.whitening_limit:
|
||||
if _whitening_metric(x.to(torch.float32), self.num_groups) > self.whitening_limit:
|
||||
# there would be a change to the grad.
|
||||
self.prob = self.max_prob
|
||||
else:
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user