Fix clamping of epsilon

This commit is contained in:
Daniel Povey 2022-10-28 12:50:14 +08:00
parent 7b8a0108ea
commit a067fe8026

View File

@ -382,8 +382,7 @@ class BasicNorm(torch.nn.Module):
# region if it happens to exit it.
eps = eps.clamp(min=self.eps_min, max=self.eps_max)
scales = (
torch.mean(x ** 2, dim=self.channel_dim, keepdim=True)
+ self.eps.exp()
torch.mean(x ** 2, dim=self.channel_dim, keepdim=True) + eps.exp()
) ** -0.5
return x * scales