Some code reworking and fixes: rationalize how the speedup is done, and fix an issue affecting the learning rate.

Daniel Povey 2022-06-23 15:40:51 +08:00
parent c34344e98f
commit 683b8e1504


@@ -244,6 +244,10 @@ class NeutralGradient(Optimizer):
                 state["exp_avg_sq"] = torch.zeros_like(p)
                 kwargs = {'device': p.device, 'dtype': p.dtype}
+                # grad_times_step is an estimate of the importance of each parameter,
+                # for purposes of determining its step_period (by which we take steps
+                # only periodically, to save time)
+                state["grad_times_step"] = torch.zeros((), **kwargs)
                 if p.numel() > 1:
                     is_one_axis = (p.numel() in list(p.shape))  # e.g. a bias
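For context, a minimal sketch of the idea behind grad_times_step (not part of the commit; beta and the helper name are illustrative): a tensor's importance is estimated as an EMA of the first-order loss decrease produced by its updates, -(delta · grad).

import torch

def update_importance(importance: torch.Tensor,
                      delta: torch.Tensor,
                      grad: torch.Tensor,
                      beta: float = 0.98) -> None:
    # If the applied step `delta` points against the gradient,
    # (delta * grad).sum() is negative; negating it yields a positive
    # score that is larger for tensors whose updates matter more.
    importance.mul_(beta).add_(-(delta * grad).sum())

importance = torch.zeros(())   # scalar state, like grad_times_step
grad = torch.randn(10, 10)
delta = -0.01 * grad           # a plain SGD-like step
update_importance(importance, delta, grad)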
@@ -376,6 +380,10 @@ class NeutralGradient(Optimizer):
                     this_delta = grad / grad_rms
                     alpha = (-(1-beta1) * lr * (bias_correction2 ** 0.5)) * scale
                     delta.add_(this_delta, alpha=alpha)
+                    if step % 10 == 0:  # this periodicity is just to save time
+                        # divide by 1-beta1 because we want the actual step...
+                        state["grad_times_step"].mul_(beta1).add_((this_delta * grad).sum(),
+                                                                  alpha=-alpha/n_cached_grads)
                 else:
                     # The full update.
                     step_within_period = state["step_within_period"]
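The step % 10 gate above amortizes the cost of the reduction. A sketch of that pattern in isolation (refresh_stat and k are illustrative names, not from the commit); a slightly staler estimate is traded for one expensive sum every k steps instead of every step:

import torch

def refresh_stat(stat: torch.Tensor, step: int, delta: torch.Tensor,
                 grad: torch.Tensor, beta: float = 0.98, k: int = 10) -> None:
    if step % k == 0:
        # (delta * grad).sum() is the expensive reduction being amortized
        stat.mul_(beta).add_(-(delta * grad).sum())

stat = torch.zeros(())
for step in range(100):
    grad = torch.randn(1000)
    refresh_stat(stat, step, -0.01 * grad, grad)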
@@ -409,7 +417,7 @@ class NeutralGradient(Optimizer):
                     cur_grad = self._change_coordinates(cur_grad, state, forward=False)
                     if random.random() < 0.00005:
-                        # This is only for debug. The logic below would not be valid for n_cache_grads>0,
+                        # This is only for debug. The logic below would not be valid for n_cached_grads>0,
                         # anyway we will delete this code at some point.
                         # in principle, the cur_grad is supposed to have the same rms as params, on average.
                         cur_grad_rms = (cur_grad**2).mean().sqrt()
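A sketch of that low-probability debug check in self-contained form (names are illustrative): the invariant being spot-checked is that the coordinate-changed gradient has roughly the same RMS as the parameter it will update.

import logging
import random
import torch

def maybe_check_rms(p: torch.Tensor, cur_grad: torch.Tensor,
                    prob: float = 0.00005) -> None:
    # Fire rarely so the extra reductions cost almost nothing on average.
    if random.random() < prob:
        cur_grad_rms = (cur_grad ** 2).mean().sqrt()
        param_rms = (p ** 2).mean().sqrt()
        logging.info(f"cur_grad rms = {cur_grad_rms.item():.3g}, "
                     f"param rms = {param_rms.item():.3g}")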
@@ -432,11 +440,16 @@ class NeutralGradient(Optimizer):
                     if param_pow != 1.0 or grad_pow != 1.0:
                         # Renormalize scale of cur_grad
                         scalar_exp_avg_sq = state["scalar_exp_avg_sq"]
-                        scalar_exp_avg_sq.mul_(beta2).add_((cur_grad**2).mean(), alpha=(1-beta2))
+                        scalar_exp_avg_sq.mul_(beta2).add_((cur_grad**2).mean()/n_cached_grads, alpha=(1-beta2))
                         alpha = alpha * scale / ((scalar_exp_avg_sq / bias_correction2).sqrt() + grad_eps)
                     delta.add_(cur_grad, alpha=alpha)
+                    if step % 10 == 0:  # this periodicity is just to save time
+                        # divide by 1-beta1 because we want the actual step...
+                        state["grad_times_step"].mul_(beta1).add_((cur_grad * grad).sum(),
+                                                                  alpha=-alpha/n_cached_grads)
                     state["step_within_period"] += 1
             else:
                 # p.numel() == 1.
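The added /n_cached_grads in the scalar_exp_avg_sq update is plausibly the learning-rate issue the commit message mentions. A sketch of why, under the assumption (mine, not stated in the diff) that the EMA should track a single step's gradient scale while n_cached_grads gradients arrive summed together: for roughly independent gradients, E[(sum of n)**2] is about n times E[g**2], so omitting the /n inflates the denominator and shrinks the effective learning rate by about sqrt(n).

import torch

n_cached_grads = 4
beta2 = 0.98
scalar_exp_avg_sq = torch.zeros(())

for _ in range(1000):
    grads = torch.randn(n_cached_grads, 100)   # n unit-variance gradients
    summed = grads.sum(dim=0)                  # what the optimizer sees
    scalar_exp_avg_sq.mul_(beta2).add_((summed ** 2).mean() / n_cached_grads,
                                       alpha=(1 - beta2))

print(scalar_exp_avg_sq)   # ~1.0 with the /n; ~4.0 without it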
@@ -449,6 +462,8 @@ class NeutralGradient(Optimizer):
                 alpha = -lr*(1-beta1)*(bias_correction2 ** 0.5)
                 delta.add_(this_delta, alpha=alpha)
+                state["grad_times_step"].mul_(beta2).add_(this_delta * grad,
+                                                          alpha=-alpha*(1-beta2)/((1-beta1)*n_cached_grads))
                 if random.random() < 0.0001:
                     logging.info(f"Delta rms = {(delta**2).mean().item()}, shape = {delta.shape}")
                 p.add_(delta)
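A sketch of the Adam-style bias correction behind the bias_correction2 factors used above (synthetic numbers, for illustration only): an EMA built with weight (1 - beta2) starts from zero and underestimates the true second moment by the factor (1 - beta2**(step+1)) early in training, and dividing by that factor removes the startup bias.

import torch

beta2 = 0.98
ema = torch.zeros(())
for step in range(20):
    g2 = torch.ones(())                  # pretend E[grad**2] == 1 exactly
    ema.mul_(beta2).add_(g2, alpha=1 - beta2)
    bias_correction2 = 1 - beta2 ** (step + 1)
    corrected = ema / bias_correction2   # ~1.0 from the very first step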
@@ -894,17 +909,9 @@ class NeutralGradient(Optimizer):
                 state = self.state[p]
                 step = state["step"]
-                prod = state["exp_avg_sq"].sqrt().sum()
-                is_one_axis = (p.numel() in list(p.shape))  # e.g. a bias
-                if is_one_axis or p.numel() == 1:
-                    if is_one_axis:
-                        prod *= state["scale"]
-                    bias_correction2 = 1 - beta2 ** (step + 1)
-                else:
-                    step_within_period = state["step_within_period"]
-                    bias_correction2 = 1 - beta2 ** (min(step, step_within_period) + 1)
-                prod /= bias_correction2 ** 0.5
-                param_prods.append(prod)
+                # the sqrt() is just a heuristic, it seems to give periods that work better.
+                param_prods.append(state["grad_times_step"].sqrt())
             param_prods = torch.stack(param_prods).to('cpu')
             # TEMP
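A sketch of how per-tensor importance scores like these might drive step periods, i.e. less important tensors get updated less often. The quartile-to-period mapping below is invented for illustration; only the "collect, stack, move to CPU" pattern mirrors the diff.

import torch

importances = [torch.tensor(3.0), torch.tensor(0.1),
               torch.tensor(1.2), torch.tensor(0.02)]
param_prods = torch.stack(importances).to('cpu')

ranks = param_prods.argsort()            # ascending: least important first
periods = torch.empty_like(ranks)
for i, r in enumerate(ranks.tolist()):
    # e.g. bottom half updates every 4 steps, top half every step
    periods[r] = 4 if i < len(ranks) // 2 else 1
print(periods)   # tensor([1, 4, 1, 4])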
@@ -1598,7 +1605,7 @@ def _test_eve_cain():
                                             estimate_period=500, stats_steps=100)
     elif iter == 3: optim = NeutralGradient(m.parameters(), lr=0.04, max_fullcov_size=1000,
                                             estimate_period=500, stats_steps=100)
-    scheduler = Eden(optim, lr_batches=200, lr_epochs=10, verbose=False)
+    scheduler = Eden(optim, lr_batches=200, lr_epochs=5, verbose=False)
     start = timeit.default_timer()
     avg_loss = 0.0
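For reference, a sketch of an Eden-style schedule factor (based on icefall's Eden; the exact formula in this repo's version may differ): the learning rate decays with both batch count and epoch count, with "knees" at lr_batches and lr_epochs, so halving lr_epochs from 10 to 5 as above makes the epoch-dependent decay kick in earlier.

def eden_factor(batch: int, epoch: int,
                lr_batches: float = 200.0, lr_epochs: float = 5.0) -> float:
    # Each term is ~1 well before its knee and decays like x**-0.5 after it.
    b = ((batch ** 2 + lr_batches ** 2) / lr_batches ** 2) ** -0.25
    e = ((epoch ** 2 + lr_epochs ** 2) / lr_epochs ** 2) ** -0.25
    return b * e

print(eden_factor(batch=1000, epoch=3))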