Mirror of https://github.com/k2-fsa/icefall.git, synced 2025-09-19 05:54:20 +00:00
Some code reworking and fixes, rationalizing how the speedup is done and fixing an issue affecting the learning rate.
parent c34344e98f
commit 683b8e1504
@@ -244,6 +244,10 @@ class NeutralGradient(Optimizer):
                 state["exp_avg_sq"] = torch.zeros_like(p)
 
                 kwargs = {'device':p.device, 'dtype':p.dtype}
+                # grad_times_step is an estimate of the importance of each parameter, for purposes
+                # of determining step_period (by which we take steps only periodically, to save
+                # time)
+                state["grad_times_step"] = torch.zeros((), **kwargs)
 
                 if p.numel() > 1:
                     is_one_axis = (p.numel() in list(p.shape))  # e.g. a bias
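The new state["grad_times_step"] scalar is what the later hunks accumulate into. As a standalone sketch of the idea only (assumed semantics and illustrative names, not the optim.py implementation): keep a decayed average of the inner product between the step actually applied and the gradient, and use it as a per-tensor importance score.

import torch

def update_importance(importance: torch.Tensor,
                      applied_step: torch.Tensor,
                      grad: torch.Tensor,
                      beta: float = 0.9) -> torch.Tensor:
    # Decayed average of -(step . grad); a larger value suggests this tensor's
    # updates matter more, so it can be given a shorter update period.
    # beta and all names here are illustrative, not taken from optim.py.
    return importance.mul_(beta).add_(-(applied_step * grad).sum(), alpha=1 - beta)

# usage sketch
p = torch.randn(4, 3, requires_grad=True)
(p ** 2).sum().backward()
applied_step = -0.01 * p.grad        # a plain SGD-like step, purely for illustration
importance = torch.zeros(())
update_importance(importance, applied_step, p.grad)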
@@ -376,6 +380,10 @@ class NeutralGradient(Optimizer):
                     this_delta = grad / grad_rms
                     alpha = (-(1-beta1) * lr * (bias_correction2 ** 0.5)) * scale
                     delta.add_(this_delta, alpha=alpha)
+
+                    if step % 10 == 0:  # this periodicity is just to save time
+                        # divide by 1-beta1 because we want the actual step...
+                        state["grad_times_step"].mul_(beta1).add_( (this_delta * grad).sum(), alpha=-alpha/n_cached_grads)
                 else:
                     # The full update.
                     step_within_period = state["step_within_period"]
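The comment "divide by 1-beta1 because we want the actual step..." can be sanity-checked numerically. The snippet below is a hedged, self-contained check: it only assumes that delta behaves like a momentum buffer updated with weight (1 - beta1) and added to the parameter each iteration, as the delta.add_(..., alpha=alpha) call above suggests. With a constant input, the per-iteration movement of the parameter converges to the input itself, so rescaling the (1 - beta1)-weighted statistic by 1/(1 - beta1) measures the step that is actually taken.

import torch

beta1 = 0.9
u = torch.tensor(0.01)     # stand-in for -lr * (grad / grad_rms), held constant
delta = torch.zeros(())
p = torch.zeros(())
prev = p.clone()
for _ in range(200):
    delta = beta1 * delta + (1 - beta1) * u   # same form as delta.add_(this_delta, alpha=alpha)
    prev = p.clone()
    p = p + delta
# per-iteration movement approaches u, not (1 - beta1) * u
print((p - prev).item(), u.item())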
@@ -409,7 +417,7 @@ class NeutralGradient(Optimizer):
                     cur_grad = self._change_coordinates(cur_grad, state, forward=False)
 
                     if random.random() < 0.00005:
-                        # This is only for debug. The logic below would not be valid for n_cache_grads>0,
+                        # This is only for debug. The logic below would not be valid for n_cached_grads>0,
                         # anyway we will delete this code at some point.
                         # in principle, the cur_grad is supposed to have the same rms as params, on average.
                         cur_grad_rms = (cur_grad**2).mean().sqrt()
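This hunk truncates the debug branch; per the comment, it checks that the coordinate-changed gradient has roughly the same rms as the parameter. A minimal sketch of such a check (assumed shape only; the real code continues in optim.py beyond this hunk):

import logging
import torch

def log_rms_ratio(cur_grad: torch.Tensor, p: torch.Tensor) -> None:
    # In principle cur_grad should have about the same rms as the parameter.
    cur_grad_rms = (cur_grad ** 2).mean().sqrt()
    param_rms = (p ** 2).mean().sqrt()
    logging.info(f"cur_grad rms = {cur_grad_rms.item():.3g}, "
                 f"param rms = {param_rms.item():.3g}, "
                 f"ratio = {(cur_grad_rms / (param_rms + 1e-20)).item():.3g}")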
@@ -432,11 +440,16 @@ class NeutralGradient(Optimizer):
                     if param_pow != 1.0 or grad_pow != 1.0:
                         # Renormalize scale of cur_grad
                         scalar_exp_avg_sq = state["scalar_exp_avg_sq"]
-                        scalar_exp_avg_sq.mul_(beta2).add_((cur_grad**2).mean(), alpha=(1-beta2))
+                        scalar_exp_avg_sq.mul_(beta2).add_((cur_grad**2).mean()/n_cached_grads, alpha=(1-beta2))
                         alpha = alpha * scale / ((scalar_exp_avg_sq / bias_correction2).sqrt() + grad_eps)
 
                     delta.add_(cur_grad, alpha=alpha)
+
+                    if step % 10 == 0:  # this periodicity is just to save time
+                        # divide by 1-beta1 because we want the actual step...
+                        state["grad_times_step"].mul_(beta1).add_( (cur_grad * grad).sum(),
+                                                                   alpha=-alpha/n_cached_grads)
 
                     state["step_within_period"] += 1
             else:
                 # p.numel() == 1.
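Besides the new grad_times_step update, the behavioural change in this hunk is dividing (cur_grad**2).mean() by n_cached_grads before it enters scalar_exp_avg_sq. One plausible reading, and it is only an assumption here, is that cur_grad aggregates n_cached_grads gradients, so its mean square is roughly n_cached_grads times that of a single gradient; without the division the rms estimate is inflated and the effective step size shrinks, which would fit the "issue affecting learning rate" in the commit message. A rough numerical illustration:

import torch

torch.manual_seed(0)
n_cached_grads = 4
grads = [torch.randn(10000) for _ in range(n_cached_grads)]
summed = torch.stack(grads).sum(dim=0)   # stand-in for an aggregated cur_grad

single_msq = (grads[0] ** 2).mean()
summed_msq = (summed ** 2).mean()
# dividing by n_cached_grads brings the aggregated estimate back to the
# single-gradient scale (for roughly independent gradients)
print(single_msq.item(), (summed_msq / n_cached_grads).item())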
@@ -449,6 +462,8 @@ class NeutralGradient(Optimizer):
                 alpha = -lr*(1-beta1)*(bias_correction2 ** 0.5)
                 delta.add_(this_delta, alpha=alpha)
 
+                state["grad_times_step"].mul_(beta2).add_(this_delta * grad, alpha=-alpha*(1-beta2)/((1-beta1)*n_cached_grads))
+
             if random.random() < 0.0001:
                 logging.info(f"Delta rms = {(delta**2).mean().item()}, shape = {delta.shape}")
             p.add_(delta)
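A hedged note on the (bias_correction2 ** 0.5) factor that appears in alpha here and in the one-axis branch earlier: if this_delta is the gradient divided by an uncorrected running rms (exp_avg_sq without bias correction), then multiplying the step by sqrt(bias_correction2) is algebraically the same as dividing by the Adam-style bias-corrected rms. That reading is an assumption about code not shown in this diff; the sketch below only illustrates the algebra, with illustrative names.

import torch

def corrected_step(grad, exp_avg_sq, step, lr, beta1=0.9, beta2=0.98, eps=1e-8):
    exp_avg_sq.mul_(beta2).add_(grad ** 2, alpha=1 - beta2)
    bias_correction2 = 1 - beta2 ** (step + 1)
    this_delta = grad / (exp_avg_sq.sqrt() + eps)        # uses the *uncorrected* rms
    alpha = -lr * (1 - beta1) * (bias_correction2 ** 0.5)
    # alpha * this_delta ~= -lr * (1 - beta1) * grad / sqrt(exp_avg_sq / bias_correction2)
    return alpha * this_delta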
@@ -894,17 +909,9 @@ class NeutralGradient(Optimizer):
                 state = self.state[p]
                 step = state["step"]
 
-                prod = state["exp_avg_sq"].sqrt().sum()
-                is_one_axis = (p.numel() in list(p.shape))  # e.g. a bias
-                if is_one_axis or p.numel() == 1:
-                    if is_one_axis:
-                        prod *= state["scale"]
-                    bias_correction2 = 1 - beta2 ** (step + 1)
-                else:
-                    step_within_period = state["step_within_period"]
-                    bias_correction2 = 1 - beta2 ** (min(step, step_within_period) + 1)
-                prod /= bias_correction2 ** 0.5
-                param_prods.append(prod)
+                # the sqrt() is just a heuristic, it seems to give periods that work better.
+                param_prods.append(state["grad_times_step"].sqrt())
+
 
             param_prods = torch.stack(param_prods).to('cpu')
             # TEMP
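How param_prods is converted into per-tensor step periods happens outside this hunk. Purely as a sketch of the kind of mapping that could follow (an assumption for illustration, not icefall's actual code): larger importance values get shorter periods, i.e. more frequent full updates.

import torch

def periods_from_importance(param_prods: torch.Tensor, max_period: int = 16) -> torch.Tensor:
    # The most important tensor gets period 1 (updated every step); the rest are
    # updated proportionally less often, capped at max_period.
    rel = param_prods / param_prods.max().clamp(min=1e-20)
    return (1.0 / rel.clamp(min=1.0 / max_period)).round().long().clamp(1, max_period)

print(periods_from_importance(torch.tensor([4.0, 1.0, 0.25, 0.01])))  # tensor([ 1,  4, 16, 16])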
@@ -1598,7 +1605,7 @@ def _test_eve_cain():
                                                 estimate_period=500, stats_steps=100)
         elif iter == 3: optim = NeutralGradient(m.parameters(), lr=0.04, max_fullcov_size=1000,
                                                 estimate_period=500, stats_steps=100)
-        scheduler = Eden(optim, lr_batches=200, lr_epochs=10, verbose=False)
+        scheduler = Eden(optim, lr_batches=200, lr_epochs=5, verbose=False)
 
         start = timeit.default_timer()
         avg_loss = 0.0