Mirror of https://github.com/k2-fsa/icefall.git, synced 2025-08-09 01:52:41 +00:00
Update optim.py (#1292)
commit fefffc02f6
parent ce08230ade
@@ -491,6 +491,12 @@ class ScaledAdam(BatchedOptimizer):
                 if self.show_dominant_parameters:
                     assert p.shape[0] == len(param_names)
                     self._show_gradient_dominating_parameter(tuples, tot_sumsq)
+            if ans != ans:  # e.g. ans is nan
+                ans = 0.0
+            if ans == 0.0:
+                for p, state, param_names in tuples:
+                    p.grad.zero_()  # get rid of infinity()
+
             return ans
 
     def _show_gradient_dominating_parameter(
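The added lines guard the gradient-clipping scale: if `ans` is NaN it is forced to 0.0, and a scale of 0.0 zeroes every gradient so that a batch with non-finite gradients is effectively skipped instead of poisoning the update. Below is a minimal, self-contained sketch of that guard, assuming plain PyTorch; `demo_clipping_guard`, `params`, and the `model_norm_threshold` default are illustrative names for this sketch, not icefall's API.

# Minimal sketch of the guard added above; toy stand-ins, not icefall code.
import torch


def demo_clipping_guard(params, model_norm_threshold=1.0):
    # Compute a clipping scale from the total gradient norm.
    tot_sumsq = sum((p.grad ** 2).sum() for p in params)
    tot_norm = tot_sumsq.sqrt()
    ans = min(1.0, (model_norm_threshold / (tot_norm + 1.0e-20)).item())

    if ans != ans:  # NaN never compares equal to itself
        ans = 0.0
    if ans == 0.0:  # inf/NaN gradients: drop them instead of scaling them
        for p in params:
            p.grad.zero_()
    return ans


p = torch.nn.Parameter(torch.ones(3))
p.grad = torch.tensor([float("inf"), 1.0, 2.0])
scale = demo_clipping_guard([p])
print(scale, p.grad)  # 0.0 tensor([0., 0., 0.])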
@@ -573,7 +579,7 @@ class ScaledAdam(BatchedOptimizer):
 
         grad = p.grad
         if clipping_scale != 1.0:
-            grad = grad * clipping_scale
+            grad *= clipping_scale
         step = state["step"]
         delta = state["delta"]
 
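The second hunk only changes how the clipping scale is applied: `grad = grad * clipping_scale` allocates a new tensor and leaves `p.grad` untouched, whereas `grad *= clipping_scale` scales `p.grad` in place and avoids the extra allocation. A small sketch of that difference, using toy tensors rather than real optimizer state:

# Sketch of the old vs. new line; toy tensors only, not the optimizer loop.
import torch

p = torch.nn.Parameter(torch.ones(4))
p.grad = torch.full((4,), 2.0)
clipping_scale = 0.5

# Old behaviour: out-of-place multiply rebinds `grad`; p.grad is unchanged.
grad = p.grad
grad = grad * clipping_scale
print(grad[0].item(), p.grad[0].item())  # 1.0 2.0

# New behaviour: in-place multiply scales p.grad itself; no new tensor.
grad = p.grad
grad *= clipping_scale
print(grad[0].item(), p.grad[0].item())  # 1.0 1.0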