mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-26 10:16:14 +00:00
minor updates
This commit is contained in:
parent
f4912b283d
commit
eb9800e121
@ -786,17 +786,10 @@ class BalancerFunction(torch.autograd.Function):
|
||||
|
||||
loss_grad = loss_grad * (grad_scale / loss_grad_rms)
|
||||
|
||||
# if x_grad.dtype == torch.float16:
|
||||
# x_grad_float = x_grad.to(torch.float32)
|
||||
# else:
|
||||
# x_grad_float = x_grad
|
||||
|
||||
# scale each element of loss_grad by the absolute value of the corresponding
|
||||
# element of x_grad, which we view as a noisy estimate of its magnitude for that
|
||||
# (frame and dimension). later we can consider factored versions.
|
||||
# x_grad_mod = x_grad_float + (x_grad_float.abs() * loss_grad)
|
||||
x_grad = x_grad + (x_grad.abs() * loss_grad)
|
||||
# x_grad = x_grad_mod.to(x_grad.dtype)
|
||||
except Exception as e:
|
||||
logging.info(
|
||||
f"Caught exception in Balancer backward: {e}, size={list(x_grad.shape)}, will continue."
|
||||
|
Loading…
x
Reference in New Issue
Block a user