First version of the speedup code that I am running

Daniel Povey 2022-06-18 15:17:51 +08:00
parent a6bf32c3e0
commit e996f7f371
2 changed files with 97 additions and 35 deletions

@@ -22,6 +22,7 @@ import random
from torch import Tensor
from torch.optim import Optimizer
from icefall import diagnostics # only for testing code
import logging
class NeutralGradient(Optimizer):
"""
@@ -90,9 +91,9 @@ class NeutralGradient(Optimizer):
estimate_period=2000,
stats_steps=200,
param_pow=1.0,
lr_for_speedup=0.1,
speedup_first_step=500, # TEMP. Make this 3500? must be > warmup.
lr_for_speedup=0.03,
speedup_recalibrate_period=100,
speedup_first_step=500,
):
if not 0.0 <= lr:
@@ -117,6 +118,10 @@ class NeutralGradient(Optimizer):
raise ValueError("Invalid estimate_period value: {}".format(estimate_period))
if not 0 < stats_steps < estimate_period:
raise ValueError("Invalid stats_steps value: {}".format(stats_steps))
if not 0 < speedup_first_step:
raise ValueError("Invalid speedup_first_step value: {}".format(speedup_first_step))
if not 0 < speedup_recalibrate_period:
raise ValueError("Invalid speedup_recalibrate_period value: {}".format(speedup_recalibrate_period))
defaults = dict(
@@ -132,9 +137,9 @@ class NeutralGradient(Optimizer):
stats_steps=stats_steps,
param_pow=param_pow,
lr_for_speedup=lr_for_speedup,
speedup_first_step=speedup_first_step,
speedup_recalibrate_period=speedup_recalibrate_period,
group_step=0,
speedup_first_step=speedup_first_step,
speedup_step=-speedup_first_step,
)
super(NeutralGradient, self).__init__(params, defaults)
@@ -171,11 +176,11 @@ class NeutralGradient(Optimizer):
speedup_first_step = group["speedup_first_step"]
speedup_recalibrate_period = group["speedup_recalibrate_period"]
group_step = group["group_step"]
group["group_step"] = group_step + 1
speedup_step = group["speedup_step"]
group["speedup_step"] = speedup_step + 1
recalibrate = (group_step >= speedup_first_step and
(group_step - speedup_first_step) % speedup_recalibrate_period == 0)
recalibrate = (speedup_step >= 0 and
speedup_step % speedup_recalibrate_period == 0)
if recalibrate:
# param_prods will be an array of products, one per parameter tensor, equal to:
# E[grad . step],
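The counter convention here is easiest to see in isolation: speedup_step starts at -speedup_first_step, is incremented on every optimizer step, and recalibration fires once the counter reaches zero and then every speedup_recalibrate_period steps thereafter (reset_speedup(), further down in this diff, re-arms the initial delay by setting the counter back to -speedup_first_step). A minimal sketch of the schedule, using a hypothetical helper that is not part of this diff:

def recalibration_steps(num_steps, speedup_first_step=500,
                        speedup_recalibrate_period=100):
    """Return the step indices at which recalibration would fire."""
    steps = []
    speedup_step = -speedup_first_step
    for t in range(num_steps):
        if speedup_step >= 0 and speedup_step % speedup_recalibrate_period == 0:
            steps.append(t)
        speedup_step += 1
    return steps

print(recalibration_steps(800))  # -> [500, 600, 700]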
@@ -270,31 +275,26 @@ class NeutralGradient(Optimizer):
step_period = state["step_period"]
n_cached_grads = state["n_cached_grads"]
if step_period > 1:
if step_period > 1 or n_cached_grads > 0:
cached_grad = state["cached_grad"]
cached_grad += grad
n_cached_grads += 1
state["n_cached_grads"] = n_cached_grads
# the point of random_step_prob is to inject a little randomness
# into the timing of the updates, so they don't stay synchronized.
random_step_prob = 0.05 / step_period
if n_cached_grads == step_period or random.random() < random_step_prob:
if n_cached_grads >= step_period or random.random() < random_step_prob:
# We will continue to the code below, and do a step
grad = cached_grad
grad += cached_grad
cached_grad.zero_()
state["n_cached_grads"] = 0
else:
cached_grad += grad
# Don't do a step this time.
continue
else:
# We are doing a step on every iteration.
n_cached_grads = 1 # will be used to scale gradients below.
# inv_sqrt_n_grads will be used to scale down estimates of grad variances
# (in situations where the scale matters), to reflect the sizes of the
# grads of individual batches rather than of the summed cached grad.
inv_sqrt_n_grads = n_cached_grads ** -0.5
delta = state["delta"]
step = state["step"]
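To make the caching scheme above concrete, here is a simplified, hypothetical re-implementation (maybe_get_step_grad is not a function in the diff): a parameter with step_period k accumulates its gradients and only takes a step roughly every k iterations, with a small random chance of stepping early so parameters do not stay synchronized.

import random
import torch

def maybe_get_step_grad(state, grad, step_period):
    if step_period > 1 or state["n_cached_grads"] > 0:
        state["cached_grad"] += grad
        state["n_cached_grads"] += 1
        # inject a little randomness into the timing of the updates
        random_step_prob = 0.05 / step_period
        if (state["n_cached_grads"] >= step_period
                or random.random() < random_step_prob):
            summed = state["cached_grad"].clone()
            state["cached_grad"].zero_()
            state["n_cached_grads"] = 0
            return summed   # do a step, using the sum of the cached grads
        return None         # skip the update on this iteration
    return grad             # step_period == 1: step on every iteration

state = {"cached_grad": torch.zeros(3), "n_cached_grads": 0}
for t in range(6):
    g = maybe_get_step_grad(state, torch.ones(3), step_period=3)
    print(t, None if g is None else g.tolist())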
@@ -410,12 +410,14 @@ class NeutralGradient(Optimizer):
if random.random() < 0.0001:
print(f"Delta rms = {(delta**2).mean().item()}, shape = {delta.shape}")
max_abs = delta.abs().max()
if max_abs > 1.0:
print("Max_abs_change = ", max_abs)
#max_abs = delta.abs().max()
#if max_abs > 1.0:
# print("Max_abs_change = ", max_abs)
p.add_(delta)
if step % 10 == 0:
p.clamp_(min=-param_max, max=param_max)
state["step"] += 1
return loss
@@ -551,6 +553,10 @@ class NeutralGradient(Optimizer):
param_rel_eps*param_rel_eps * params_sq.mean())
params_sq_norm = params_sq
if param_pow != 1.0:
params_sq_partnorm = params_sq
for _ in range(3):
# Iterate 3 times, this should be enough to converge.
for dim in range(p.ndim):
@@ -561,8 +567,10 @@ class NeutralGradient(Optimizer):
other_dims = [ i for i in range(p.ndim) if i != dim ]
# Compute diagonal variance along this dimension
# (this is after normalizing previous dimensions' variance)
this_var = params_sq.mean(dim=other_dims, keepdim=True)
params_sq = params_sq / this_var
this_var = params_sq_norm.mean(dim=other_dims, keepdim=True)
params_sq_norm = params_sq_norm / this_var
if param_pow != 1.0:
params_sq_partnorm = params_sq_partnorm / (this_var ** param_pow)
this_var = this_var.reshape(size)
this_scale = (this_var ** (param_pow * 0.5)).reshape(size)
@@ -576,6 +584,17 @@ class NeutralGradient(Optimizer):
# diagonalized space.
proj *= this_scale.unsqueeze(-1)
if param_pow != 1.0:
# need to get the overall scale correct, as if we had param_pow == 1.0
scale = (params_sq_partnorm.mean() ** 0.5)
for dim in range(p.ndim):
size = p.shape[dim]
if size == 1:
continue
else:
state[f"proj_{dim}"] *= scale
break
# Reset the squared-gradient stats because we have changed the basis.
state["exp_avg_sq"].fill_(0.0)
@@ -789,13 +808,29 @@ class NeutralGradient(Optimizer):
return P
def reset_speedup(self):
"""
Reset the step_period of every parameter so that we'll do a step on each
iteration, until the next time the speedup is recalibrated.
"""
for group in self.param_groups:
group["speedup_step"] = -group["speedup_first_step"]
for p in group["params"]:
state = self.state[p]
state["step_period"] = 1
def _recalibrate_speedup(self, group: dict):
param_prods = []
speedup = group["lr_for_speedup"] / group["lr"]
if speedup > 10.0:
speedup = 10.0
if speedup < 2.0:
speedup = 2.0
if speedup <= 1.0:
for p in group["params"]:
state = self.state[p]
state["step_period"] = 1
return
_, beta2 = group["betas"]
for p in group["params"]:
@@ -819,13 +854,16 @@ class NeutralGradient(Optimizer):
param_prods.append(prod)
param_prods = torch.stack(param_prods).to('cpu')
# TEMP
param_prods = param_prods.sqrt()
avg_update_freq = 1.0 / speedup
param_freqs = param_prods * (avg_update_freq / param_prods.mean())
# Don't delay more than 40 steps for any parameter. Note: with beta=0.1,
# this would mean an average delay of about 500 steps before the gradient
# makes it to the parameter.
min_freq = 1.0 / 40
# Don't let any parameter's update frequency fall below 1 / (3*speedup),
# i.e. don't delay more than 3*speedup steps for any parameter, to
# ensure that no parameter waits super-long for an update. Note: with
# beta=0.1, this would mean an average delay of about 30*speedup steps
# before the gradient makes it to the parameter.
min_freq = 1.0 / (3 * speedup)
# get the mean after applying limits of [min_freq..1] for the frequencies of
# update
@@ -836,15 +874,23 @@ class NeutralGradient(Optimizer):
param_freqs = param_freqs.clamp(min=min_freq, max=1.0)
param_periods = (1.0 / param_freqs).to(torch.int32) # .to(torch.int32) rounds down
actual_speedup = 1.0 / (1.0 / param_periods).mean()
param_prods = param_prods.tolist()
param_freqs = param_freqs.tolist()
param_periods = param_periods.tolist()
print("speedup = ", speedup)
for i,p in enumerate(group["params"]):
logging.info(f"NeutralGradient._recalibrate_speedup: speedup = {speedup:.2g}, actual_speedup = {actual_speedup:.2g}")
print_info = True # random.random() < 0.1
i = 0
for p in group["params"]:
if p.grad is None:
continue
print(f"Shape = {p.shape}, prod = {param_prods[i]}, freq = {param_freqs[i]}, period = {param_periods[i]}.")
state = self.state[p]
state["step_period"] = param_periods[i]
if print_info:
logging.info(f"Shape = {p.shape}, prod = {param_prods[i]}, freq = {param_freqs[i]}, period = {param_periods[i]}.")
i += 1
# Do nothing more, as yet.
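The arithmetic of _recalibrate_speedup is easy to sanity-check in isolation. A hedged sketch with made-up products standing in for E[grad . step] (compute_periods is a hypothetical helper, not part of the diff): the sqrt compresses the dynamic range of the products, the frequencies are scaled so their mean is 1/speedup, then clamped to [1/(3*speedup), 1] and inverted to integer step periods.

import torch

def compute_periods(param_prods, speedup):
    param_prods = param_prods.sqrt()
    avg_update_freq = 1.0 / speedup
    param_freqs = param_prods * (avg_update_freq / param_prods.mean())
    min_freq = 1.0 / (3 * speedup)  # longest allowed wait: 3*speedup steps
    param_freqs = param_freqs.clamp(min=min_freq, max=1.0)
    param_periods = (1.0 / param_freqs).to(torch.int32)  # rounds down
    actual_speedup = 1.0 / (1.0 / param_periods).mean()
    return param_periods, actual_speedup

periods, actual = compute_periods(torch.tensor([4.0, 1.0, 0.25, 0.01]),
                                  speedup=3.0)
print(periods.tolist(), float(actual))  # larger products -> shorter periods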
@@ -1507,6 +1553,8 @@ def _test_eve_cain():
avg_loss = 0.0
for epoch in range(150):
scheduler.step_epoch()
if epoch == 100 and iter in [2,3]:
optim.reset_speedup() # check it doesn't crash.
if epoch == 130:
opts = diagnostics.TensorDiagnosticOptions(
@@ -1531,7 +1579,8 @@ def _test_eve_cain():
#scale1b = '%.2e' % (m[0].bias_scale.exp().item())
#scale2 = '%.2e' % (m[2].weight_scale.exp().item())
#scale2b = '%.2e' % (m[2].bias_scale.exp().item())
print(f"Iter {iter}, epoch {epoch}, batch {n}, avg_loss {avg_loss:.4g}, norms={norm1,norm1b,norm2,norm2b}") # scales={scale1,scale1b,scale2,scale2b}
lr = scheduler.get_last_lr()[0]
print(f"Iter {iter}, epoch {epoch}, batch {n}, avg_loss {avg_loss:.4g}, lr={lr:.4e}, norms={norm1,norm1b,norm2,norm2b}") # scales={scale1,scale1b,scale2,scale2b}
loss.log().backward()
optim.step()
optim.zero_grad()
@@ -1562,6 +1611,7 @@ def _test_eve_cain():
if __name__ == "__main__":
logging.getLogger().setLevel(logging.INFO)
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
_test_eve_cain()

@@ -762,6 +762,16 @@ def train_one_epoch(
# in the batch and there is no normalization to it so far.
scaler.scale(loss).backward()
scheduler.step_batch(params.batch_idx_train)
if params.batch_idx_train == params.model_warm_step:
# we're about to start using the pruned loss, which brings new
# modules into play, so reset the frequencies of update, to
# avoid possible instability.
try:
optimizer.reset_speedup()
logging.info("Reset speedup on optimizer")
except AttributeError:
# the optimizer in use may not implement reset_speedup().
pass
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad()
@@ -916,7 +926,9 @@ def run(rank, world_size, args):
logging.info("Using DDP")
model = DDP(model, device_ids=[rank])
optimizer = NeutralGradient(model.parameters(), lr=params.initial_lr)
optimizer = NeutralGradient(model.parameters(),
lr=params.initial_lr,
lr_for_speedup=params.initial_lr)
scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)
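A note on this configuration: passing lr_for_speedup=params.initial_lr (rather than the optimizer's 0.03 default) makes the target speedup lr_for_speedup / lr equal to 1 at the start of training, so the speedup <= 1.0 early-return in _recalibrate_speedup keeps every step_period at 1, and the speedup only kicks in as the Eden scheduler decays the learning rate. A hedged illustration with a made-up initial_lr:

initial_lr = 0.003  # hypothetical value for params.initial_lr
for lr in [3e-3, 1e-3, 3e-4, 1e-4]:
    speedup = initial_lr / lr  # grows as the scheduler decays lr
    print(f"lr={lr:g} -> target speedup {speedup:.1f}x")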