Add and test reset() function

This commit is contained in:
Daniel Povey 2022-05-15 16:20:10 +08:00
parent a1dc020270
commit bb32556f9e

View File

@ -347,8 +347,9 @@ class Abel(Optimizer):
factor_grads = _get_factor_grads(p, grad)
if step < 10 or step % 10 == 1:
# note, the first step number we'll see is 1.
# update the factorization this only every 10 steps, to save time.
if step % 1000 == 0:
if step % 1000 == 1:
# Periodically refresh the factorization; this is
# out of concern that after a large number of
# multiplications, roundoff could cause it to drift
@ -421,7 +422,14 @@ class Abel(Optimizer):
exp_avg_sq values will be substantially too small and prevents any
too-fast updates.
"""
pass
for s in self.state.values():
s["delta"].zero_() # zero out momentum
s["step"] = 0 # will cause state["factorization"] to be reinitialized
s["exp_avg_sq"].zero_()
if "factors_exp_avg_sq" in s:
for e in s["factors_exp_avg_sq"]:
e.zero_()
class Eve(Optimizer):
@ -764,6 +772,10 @@ def _test_abel():
start = timeit.default_timer()
for epoch in range(150):
scheduler.step_epoch()
if isinstance(optim, Abel) and epoch % 5 == 0: # this is every 100 minibatches.
optim.reset() # just test that calling reset() doesn't
# break things, you wouldn't call it every
# epoch like this.
for n, (x,y) in enumerate(train_pairs):
y_out = m(x)
loss = ((y_out - y)**2).mean() * 100.0