mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-07 08:04:18 +00:00
Add and test reset() function
This commit is contained in:
parent
a1dc020270
commit
bb32556f9e
@ -347,8 +347,9 @@ class Abel(Optimizer):
|
||||
factor_grads = _get_factor_grads(p, grad)
|
||||
|
||||
if step < 10 or step % 10 == 1:
|
||||
# note, the first step number we'll see is 1.
|
||||
# update the factorization this only every 10 steps, to save time.
|
||||
if step % 1000 == 0:
|
||||
if step % 1000 == 1:
|
||||
# Periodically refresh the factorization; this is
|
||||
# out of concern that after a large number of
|
||||
# multiplications, roundoff could cause it to drift
|
||||
@ -421,7 +422,14 @@ class Abel(Optimizer):
|
||||
exp_avg_sq values will be substantially too small and prevents any
|
||||
too-fast updates.
|
||||
"""
|
||||
pass
|
||||
for s in self.state.values():
|
||||
s["delta"].zero_() # zero out momentum
|
||||
s["step"] = 0 # will cause state["factorization"] to be reinitialized
|
||||
s["exp_avg_sq"].zero_()
|
||||
if "factors_exp_avg_sq" in s:
|
||||
for e in s["factors_exp_avg_sq"]:
|
||||
e.zero_()
|
||||
|
||||
|
||||
|
||||
class Eve(Optimizer):
|
||||
@ -764,6 +772,10 @@ def _test_abel():
|
||||
start = timeit.default_timer()
|
||||
for epoch in range(150):
|
||||
scheduler.step_epoch()
|
||||
if isinstance(optim, Abel) and epoch % 5 == 0: # this is every 100 minibatches.
|
||||
optim.reset() # just test that calling reset() doesn't
|
||||
# break things, you wouldn't call it every
|
||||
# epoch like this.
|
||||
for n, (x,y) in enumerate(train_pairs):
|
||||
y_out = m(x)
|
||||
loss = ((y_out - y)**2).mean() * 100.0
|
||||
|
Loading…
x
Reference in New Issue
Block a user