Update optim.py

jinzr 2023-10-07 10:05:53 +08:00
parent 1b565dd251
commit 542698789a


@@ -491,6 +491,12 @@ class ScaledAdam(BatchedOptimizer):
                 if self.show_dominant_parameters:
                     assert p.shape[0] == len(param_names)
                     self._show_gradient_dominating_parameter(tuples, tot_sumsq)
+            if ans != ans:  # e.g. ans is nan
+                ans = 0.0
+            if ans == 0.0:
+                for p, state, param_names in tuples:
+                    p.grad.zero_()  # get rid of infinity()
+
             return ans
 
     def _show_gradient_dominating_parameter(
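
The added guard uses the fact that NaN is the only floating-point value that compares unequal to itself, so `ans != ans` is true exactly when the computed clipping scale is NaN; zeroing every `p.grad` in that case (or when the scale comes out as 0.0 because the norm was infinite) keeps a poisoned gradient batch from reaching the parameter update. The following is a minimal standalone sketch of the same guard outside the optimizer; the threshold constant and the simplified norm formula are illustrative assumptions, not the exact expressions used in optim.py.

import torch

model_norm_threshold = 2.0  # illustrative constant, not a value from icefall

for bad in (float("nan"), float("inf")):
    grad = torch.tensor([1.0, bad, -2.0])      # a gradient poisoned by nan or inf
    tot_norm = (grad ** 2).sum().sqrt()        # nan or inf, respectively
    ans = (model_norm_threshold / (tot_norm + 1.0e-20)).item()  # nan or 0.0

    if ans != ans:    # true only when ans is nan
        ans = 0.0
    if ans == 0.0:
        grad.zero_()  # discard the unusable gradient so nan/inf cannot propagate

    print(ans, grad)  # 0.0 and an all-zero gradient in both cases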
@@ -573,7 +579,7 @@ class ScaledAdam(BatchedOptimizer):
         grad = p.grad
         if clipping_scale != 1.0:
-            grad = grad * clipping_scale
+            grad *= clipping_scale
         step = state["step"]
         delta = state["delta"]
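
In the second hunk, `grad` is simply another reference to `p.grad`, so the in-place form `grad *= clipping_scale` scales the stored gradient itself without allocating a new tensor, whereas `grad = grad * clipping_scale` rebinds `grad` to a fresh tensor and leaves `p.grad` untouched. A small sketch of that aliasing difference (the tensor values are arbitrary):

import torch

p = torch.zeros(3, requires_grad=True)
p.grad = torch.tensor([2.0, 4.0, 6.0])
clipping_scale = 0.5

grad = p.grad
grad = grad * clipping_scale   # out-of-place: new tensor, p.grad stays [2., 4., 6.]
print(p.grad)

grad = p.grad
grad *= clipping_scale         # in-place: p.grad itself becomes [1., 2., 3.]
print(p.grad)

Whether the in-place variant is preferable depends on whether the unscaled `p.grad` is needed again later in the same step; the sketch only demonstrates the aliasing behavior, not that trade-off.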