Do average computation in double precision

This commit is contained in:
Daniel Povey 2022-05-31 14:39:21 +08:00
parent b2259184b5
commit da2ffd4d27

View File

@ -476,6 +476,10 @@ def average_tensor(
t1.mul_(weight_1)
t1.add_(t2, alpha=weight_2)
else:
# do this in double precision to reduce roundoff error.
output = t1
t1 = t1.to(torch.float64)
t2 = t2.to(torch.float64)
eps = 1.0e-05
scale_1 = (t1 ** 2).mean().sqrt() + eps
direction_1 = t1 / scale_1
@ -485,7 +489,7 @@ def average_tensor(
log_scale_2 = scale_2.log()
average_tensor(log_scale_1, log_scale_2, weight_1, weight_2, False)
average_tensor(direction_1, direction_2, weight_1, weight_2, False)
t1.copy_(log_scale_1.exp() * direction_1)
output.copy_((log_scale_1.exp() * direction_1).to(dtype=t1.dtype))
def average_state_dict(