From bb32556f9e318c9a48f950e7fb47e7fc6dca7606 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sun, 15 May 2022 16:20:10 +0800 Subject: [PATCH] Add and test reset() function --- .../ASR/pruned_transducer_stateless4b/optim.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4b/optim.py b/egs/librispeech/ASR/pruned_transducer_stateless4b/optim.py index 9aaf356cf..840ef4a58 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless4b/optim.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless4b/optim.py @@ -347,8 +347,9 @@ class Abel(Optimizer): factor_grads = _get_factor_grads(p, grad) if step < 10 or step % 10 == 1: + # note, the first step number we'll see is 1. # update the factorization this only every 10 steps, to save time. - if step % 1000 == 0: + if step % 1000 == 1: # Periodically refresh the factorization; this is # out of concern that after a large number of # multiplications, roundoff could cause it to drift @@ -421,7 +422,14 @@ class Abel(Optimizer): exp_avg_sq values will be substantially too small and prevents any too-fast updates. """ - pass + for s in self.state.values(): + s["delta"].zero_() # zero out momentum + s["step"] = 0 # will cause state["factorization"] to be reinitialized + s["exp_avg_sq"].zero_() + if "factors_exp_avg_sq" in s: + for e in s["factors_exp_avg_sq"]: + e.zero_() + class Eve(Optimizer): @@ -764,6 +772,10 @@ def _test_abel(): start = timeit.default_timer() for epoch in range(150): scheduler.step_epoch() + if isinstance(optim, Abel) and epoch % 5 == 0: # this is every 100 minibatches. + optim.reset() # just test that calling reset() doesn't + # break things, you wouldn't call it every + # epoch like this. for n, (x,y) in enumerate(train_pairs): y_out = m(x) loss = ((y_out - y)**2).mean() * 100.0