From 683b8e15048c79c52be4c7531b8786fbe0b1075d Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Thu, 23 Jun 2022 15:40:51 +0800
Subject: [PATCH] Some code reworking and fixes, rationalizing how the speedup
 is done and fixing an issue affecting the learning rate.

---
 .../ASR/pruned_transducer_stateless7/optim.py | 35 +++++++++++--------
 1 file changed, 21 insertions(+), 14 deletions(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
index d8e28b6b1..8b46199cb 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
@@ -244,6 +244,10 @@ class NeutralGradient(Optimizer):
                 state["exp_avg_sq"] = torch.zeros_like(p)
 
                 kwargs = {'device':p.device, 'dtype':p.dtype}
+                # grad_times_step is an estimate of the importance of each parameter, used
+                # for determining step_period (the period with which we take steps; we only
+                # step periodically, to save time).
+                state["grad_times_step"] = torch.zeros((), **kwargs)
 
                 if p.numel() > 1:
                     is_one_axis = (p.numel() in list(p.shape))  # e.g. a bias
@@ -376,6 +380,10 @@
                         this_delta = grad / grad_rms
                         alpha = (-(1-beta1) * lr * (bias_correction2 ** 0.5)) * scale
                         delta.add_(this_delta, alpha=alpha)
+
+                        if step % 10 == 0:  # this periodicity is just to save time
+                            # divide by 1-beta1 because we want the actual step...
+                            state["grad_times_step"].mul_(beta1).add_((this_delta * grad).sum(), alpha=-alpha/n_cached_grads)
                     else:
                         # The full update.
                         step_within_period = state["step_within_period"]
@@ -409,7 +417,7 @@
                         cur_grad = self._change_coordinates(cur_grad, state, forward=False)
 
                         if random.random() < 0.00005:
-                            # This is only for debug. The logic below would not be valid for n_cache_grads>0,
+                            # This is only for debug. The logic below would not be valid for n_cached_grads>0,
                             # anyway we will delete this code at some point.
                             # in principle, the cur_grad is supposed to have the same rms as params, on average.
                             cur_grad_rms = (cur_grad**2).mean().sqrt()
@@ -432,11 +440,16 @@
                        if param_pow != 1.0 or grad_pow != 1.0:
                            # Renormalize scale of cur_grad
                            scalar_exp_avg_sq = state["scalar_exp_avg_sq"]
-                           scalar_exp_avg_sq.mul_(beta2).add_((cur_grad**2).mean(), alpha=(1-beta2))
+                           scalar_exp_avg_sq.mul_(beta2).add_((cur_grad**2).mean()/n_cached_grads, alpha=(1-beta2))
                            alpha = alpha * scale / ((scalar_exp_avg_sq / bias_correction2).sqrt() + grad_eps)
 
                        delta.add_(cur_grad, alpha=alpha)
 
+                       if step % 10 == 0:  # this periodicity is just to save time
+                           # divide by 1-beta1 because we want the actual step...
+                           state["grad_times_step"].mul_(beta1).add_((cur_grad * grad).sum(),
+                                                                     alpha=-alpha/n_cached_grads)
+
                        state["step_within_period"] += 1
                 else:
                     # p.numel() == 1.
@@ -449,6 +462,8 @@
                     alpha = -lr*(1-beta1)*(bias_correction2 ** 0.5)
                     delta.add_(this_delta, alpha=alpha)
 
+                    state["grad_times_step"].mul_(beta2).add_(this_delta * grad, alpha=-alpha*(1-beta2)/((1-beta1)*n_cached_grads))
+
             if random.random() < 0.0001:
                 logging.info(f"Delta rms = {(delta**2).mean().item()}, shape = {delta.shape}")
             p.add_(delta)
@@ -894,17 +909,9 @@
                 state = self.state[p]
                 step = state["step"]
 
-                prod = state["exp_avg_sq"].sqrt().sum()
-                is_one_axis = (p.numel() in list(p.shape))  # e.g. a bias
-                if is_one_axis or p.numel() == 1:
-                    if is_one_axis:
-                        prod *= state["scale"]
-                    bias_correction2 = 1 - beta2 ** (step + 1)
-                else:
-                    step_within_period = state["step_within_period"]
-                    bias_correction2 = 1 - beta2 ** (min(step, step_within_period) + 1)
-                prod /= bias_correction2 ** 0.5
-                param_prods.append(prod)
+                # the sqrt() is just a heuristic; it seems to give periods that work better.
+                param_prods.append(state["grad_times_step"].sqrt())
+
         param_prods = torch.stack(param_prods).to('cpu')
 
         # TEMP
@@ -1598,7 +1605,7 @@ def _test_eve_cain():
                                                 estimate_period=500, stats_steps=100)
         elif iter == 3: optim = NeutralGradient(m.parameters(), lr=0.04, max_fullcov_size=1000,
                                                 estimate_period=500, stats_steps=100)
-        scheduler = Eden(optim, lr_batches=200, lr_epochs=10, verbose=False)
+        scheduler = Eden(optim, lr_batches=200, lr_epochs=5, verbose=False)
         start = timeit.default_timer()

         avg_loss = 0.0
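
Illustrative note (not part of the patch): the new state["grad_times_step"] accumulates, with exponential decay, a quantity proportional to -(step * grad).sum() for each parameter, i.e. roughly the first-order decrease in loss attributable to that parameter's updates; the param_prods hunk (@@ -894,17 +909,9 @@) then ranks parameters by its sqrt() when deciding how often each one needs a full update. The standalone sketch below shows the same bookkeeping around a plain SGD-style update; the names (importance, toy_loss), shapes and constants are invented for illustration and do not appear in optim.py.

# Standalone sketch (not part of the patch; names and constants are invented):
# keep a decayed running sum of -(step * grad).sum() per parameter as a rough
# importance score, analogous to state["grad_times_step"] in the patch.
import torch

torch.manual_seed(0)
params = [torch.randn(100, 50, requires_grad=True),
          torch.randn(50, requires_grad=True)]
importance = [torch.zeros(()) for _ in params]   # one scalar score per parameter
lr, beta = 0.1, 0.98

def toy_loss(ps):
    # Any differentiable loss would do; a quadratic keeps the sketch simple.
    return sum((p ** 2).sum() for p in ps)

for it in range(200):
    loss = toy_loss(params)
    grads = torch.autograd.grad(loss, params)
    with torch.no_grad():
        for p, g, imp in zip(params, grads, importance):
            delta = -lr * g                      # the step about to be taken
            p.add_(delta)
            # -(delta * grad).sum() >= 0 for a descent step; accumulate it with
            # exponential decay, as the patch does for grad_times_step.
            imp.mul_(beta).add_(-(delta * g).sum(), alpha=1 - beta)

# Rank parameters by (the sqrt of) their importance score; in the patch an
# analogous ranking is used to choose per-parameter step periods, so that
# less important parameters get full updates less often.
scores = torch.stack([imp.clamp(min=0.0).sqrt() for imp in importance])
print("importance scores:", scores.tolist())
print("most important parameter index:", torch.argmax(scores).item())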