mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-08 08:34:19 +00:00
Do average computation in double precision
This commit is contained in:
parent
b2259184b5
commit
da2ffd4d27
@ -476,6 +476,10 @@ def average_tensor(
|
||||
t1.mul_(weight_1)
|
||||
t1.add_(t2, alpha=weight_2)
|
||||
else:
|
||||
# do this in double precision to reduce roundoff error.
|
||||
output = t1
|
||||
t1 = t1.to(torch.float64)
|
||||
t2 = t2.to(torch.float64)
|
||||
eps = 1.0e-05
|
||||
scale_1 = (t1 ** 2).mean().sqrt() + eps
|
||||
direction_1 = t1 / scale_1
|
||||
@ -485,7 +489,7 @@ def average_tensor(
|
||||
log_scale_2 = scale_2.log()
|
||||
average_tensor(log_scale_1, log_scale_2, weight_1, weight_2, False)
|
||||
average_tensor(direction_1, direction_2, weight_1, weight_2, False)
|
||||
t1.copy_(log_scale_1.exp() * direction_1)
|
||||
output.copy_((log_scale_1.exp() * direction_1).to(dtype=t1.dtype))
|
||||
|
||||
|
||||
def average_state_dict(
|
||||
|
Loading…
x
Reference in New Issue
Block a user