mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-08 08:34:19 +00:00
Add and test reset() function
This commit is contained in:
parent
a1dc020270
commit
bb32556f9e
@ -347,8 +347,9 @@ class Abel(Optimizer):
|
|||||||
factor_grads = _get_factor_grads(p, grad)
|
factor_grads = _get_factor_grads(p, grad)
|
||||||
|
|
||||||
if step < 10 or step % 10 == 1:
|
if step < 10 or step % 10 == 1:
|
||||||
|
# note, the first step number we'll see is 1.
|
||||||
# update the factorization only every 10 steps, to save time.
|
# update the factorization only every 10 steps, to save time.
|
||||||
if step % 1000 == 0:
|
if step % 1000 == 1:
|
||||||
# Periodically refresh the factorization; this is
|
# Periodically refresh the factorization; this is
|
||||||
# out of concern that after a large number of
|
# out of concern that after a large number of
|
||||||
# multiplications, roundoff could cause it to drift
|
# multiplications, roundoff could cause it to drift
|
||||||
@ -421,7 +422,14 @@ class Abel(Optimizer):
|
|||||||
exp_avg_sq values will be substantially too small and prevents any
|
exp_avg_sq values will be substantially too small and prevents any
|
||||||
too-fast updates.
|
too-fast updates.
|
||||||
"""
|
"""
|
||||||
pass
|
for s in self.state.values():
|
||||||
|
s["delta"].zero_() # zero out momentum
|
||||||
|
s["step"] = 0 # will cause state["factorization"] to be reinitialized
|
||||||
|
s["exp_avg_sq"].zero_()
|
||||||
|
if "factors_exp_avg_sq" in s:
|
||||||
|
for e in s["factors_exp_avg_sq"]:
|
||||||
|
e.zero_()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class Eve(Optimizer):
|
class Eve(Optimizer):
|
||||||
@ -764,6 +772,10 @@ def _test_abel():
|
|||||||
start = timeit.default_timer()
|
start = timeit.default_timer()
|
||||||
for epoch in range(150):
|
for epoch in range(150):
|
||||||
scheduler.step_epoch()
|
scheduler.step_epoch()
|
||||||
|
if isinstance(optim, Abel) and epoch % 5 == 0: # this is every 100 minibatches.
|
||||||
|
optim.reset() # just test that calling reset() doesn't
|
||||||
|
# break things, you wouldn't call it every
|
||||||
|
# epoch like this.
|
||||||
for n, (x,y) in enumerate(train_pairs):
|
for n, (x,y) in enumerate(train_pairs):
|
||||||
y_out = m(x)
|
y_out = m(x)
|
||||||
loss = ((y_out - y)**2).mean() * 100.0
|
loss = ((y_out - y)**2).mean() * 100.0
|
||||||
|
Loading…
x
Reference in New Issue
Block a user