Do average computation in double precision
This commit is contained in:
parent
b2259184b5
commit
da2ffd4d27
@ -476,6 +476,10 @@ def average_tensor(
|
|||||||
t1.mul_(weight_1)
|
t1.mul_(weight_1)
|
||||||
t1.add_(t2, alpha=weight_2)
|
t1.add_(t2, alpha=weight_2)
|
||||||
else:
|
else:
|
||||||
|
# do this in double precision to reduce roundoff error.
|
||||||
|
output = t1
|
||||||
|
t1 = t1.to(torch.float64)
|
||||||
|
t2 = t2.to(torch.float64)
|
||||||
eps = 1.0e-05
|
eps = 1.0e-05
|
||||||
scale_1 = (t1 ** 2).mean().sqrt() + eps
|
scale_1 = (t1 ** 2).mean().sqrt() + eps
|
||||||
direction_1 = t1 / scale_1
|
direction_1 = t1 / scale_1
|
||||||
@ -485,7 +489,7 @@ def average_tensor(
|
|||||||
log_scale_2 = scale_2.log()
|
log_scale_2 = scale_2.log()
|
||||||
average_tensor(log_scale_1, log_scale_2, weight_1, weight_2, False)
|
average_tensor(log_scale_1, log_scale_2, weight_1, weight_2, False)
|
||||||
average_tensor(direction_1, direction_2, weight_1, weight_2, False)
|
average_tensor(direction_1, direction_2, weight_1, weight_2, False)
|
||||||
t1.copy_(log_scale_1.exp() * direction_1)
|
output.copy_((log_scale_1.exp() * direction_1).to(dtype=t1.dtype))
|
||||||
|
|
||||||
|
|
||||||
def average_state_dict(
|
def average_state_dict(
|
||||||
|
|||||||
Reference in New Issue
Block a user