minor updates

This commit is contained in:
marcoyang 2024-07-24 15:05:13 +08:00
parent f4912b283d
commit eb9800e121

View File

@ -786,17 +786,10 @@ class BalancerFunction(torch.autograd.Function):
loss_grad = loss_grad * (grad_scale / loss_grad_rms)
# if x_grad.dtype == torch.float16:
# x_grad_float = x_grad.to(torch.float32)
# else:
# x_grad_float = x_grad
# scale each element of loss_grad by the absolute value of the corresponding
# element of x_grad, which we view as a noisy estimate of its magnitude for that
# (frame and dimension). later we can consider factored versions.
# x_grad_mod = x_grad_float + (x_grad_float.abs() * loss_grad)
x_grad = x_grad + (x_grad.abs() * loss_grad)
# x_grad = x_grad_mod.to(x_grad.dtype)
except Exception as e:
logging.info(
f"Caught exception in Balancer backward: {e}, size={list(x_grad.shape)}, will continue."