mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-19 05:54:20 +00:00
First version I am running, of the speedup code
This commit is contained in:
parent
a6bf32c3e0
commit
e996f7f371
@ -22,6 +22,7 @@ import random
|
||||
from torch import Tensor
|
||||
from torch.optim import Optimizer
|
||||
from icefall import diagnostics # only for testing code
|
||||
import logging
|
||||
|
||||
class NeutralGradient(Optimizer):
|
||||
"""
|
||||
@ -90,9 +91,9 @@ class NeutralGradient(Optimizer):
|
||||
estimate_period=2000,
|
||||
stats_steps=200,
|
||||
param_pow=1.0,
|
||||
lr_for_speedup=0.1,
|
||||
speedup_first_step=500, # TEMP. Make this 3500? must be > warmup.
|
||||
lr_for_speedup=0.03,
|
||||
speedup_recalibrate_period=100,
|
||||
speedup_first_step=500,
|
||||
):
|
||||
|
||||
if not 0.0 <= lr:
|
||||
@ -117,6 +118,10 @@ class NeutralGradient(Optimizer):
|
||||
raise ValueError("Invalid estimate_period value: {}".format(estimate_period))
|
||||
if not 0 < stats_steps < estimate_period:
|
||||
raise ValueError("Invalid stats_steps value: {}".format(stats_steps))
|
||||
if not 0 < speedup_first_step:
|
||||
raise ValueError("Invalid speedup_first_step value: {}".format(speedup_first_step))
|
||||
if not 0 < speedup_recalibrate_period:
|
||||
raise ValueError("Invalid speedup_recalibrate_period value: {}".format(speedup_recalibrate_period))
|
||||
|
||||
|
||||
defaults = dict(
|
||||
@ -132,9 +137,9 @@ class NeutralGradient(Optimizer):
|
||||
stats_steps=stats_steps,
|
||||
param_pow=param_pow,
|
||||
lr_for_speedup=lr_for_speedup,
|
||||
speedup_first_step=speedup_first_step,
|
||||
speedup_recalibrate_period=speedup_recalibrate_period,
|
||||
group_step=0,
|
||||
speedup_first_step=speedup_first_step,
|
||||
speedup_step=-speedup_first_step,
|
||||
)
|
||||
super(NeutralGradient, self).__init__(params, defaults)
|
||||
|
||||
@ -171,11 +176,11 @@ class NeutralGradient(Optimizer):
|
||||
speedup_first_step = group["speedup_first_step"]
|
||||
speedup_recalibrate_period = group["speedup_recalibrate_period"]
|
||||
|
||||
group_step = group["group_step"]
|
||||
group["group_step"] = group_step + 1
|
||||
speedup_step = group["speedup_step"]
|
||||
group["speedup_step"] = speedup_step + 1
|
||||
|
||||
recalibrate = (group_step >= speedup_first_step and
|
||||
(group_step - speedup_first_step) % speedup_recalibrate_period == 0)
|
||||
recalibrate = (speedup_step >= 0 and
|
||||
speedup_step % speedup_recalibrate_period == 0)
|
||||
if recalibrate:
|
||||
# param_prods will be an array of products, one per parameter tensor, equal to:
|
||||
# E[grad . step],
|
||||
@ -270,31 +275,26 @@ class NeutralGradient(Optimizer):
|
||||
|
||||
step_period = state["step_period"]
|
||||
n_cached_grads = state["n_cached_grads"]
|
||||
if step_period > 1:
|
||||
if step_period > 1 or n_cached_grads > 0:
|
||||
cached_grad = state["cached_grad"]
|
||||
cached_grad += grad
|
||||
n_cached_grads += 1
|
||||
state["n_cached_grads"] = n_cached_grads
|
||||
# the point of random_step_prob is to inject a little randomness
|
||||
# into the timing of the updates, so they don't stay synchronized.
|
||||
random_step_prob = 0.05 / step_period
|
||||
if n_cached_grads == step_period or random.random() < random_step_prob:
|
||||
if n_cached_grads >= step_period or random.random() < random_step_prob:
|
||||
# We will continue to the code below, and do a step
|
||||
grad = cached_grad
|
||||
grad += cached_grad
|
||||
cached_grad.zero_()
|
||||
state["n_cached_grads"] = 0
|
||||
else:
|
||||
cached_grad += grad
|
||||
# Don't do a step this time.
|
||||
continue
|
||||
else:
|
||||
# We are doing a step on every iteration.
|
||||
n_cached_grads = 1 # will be used to scale gradients below.
|
||||
|
||||
# sqrt_n_grads will be used to scale down estimates of grad variances
|
||||
# (in situations where the scale matters), to reflect the sizes of
|
||||
# grads of individual
|
||||
inv_sqrt_n_grads = n_cached_grads ** 0.5
|
||||
|
||||
|
||||
|
||||
delta = state["delta"]
|
||||
step = state["step"]
|
||||
|
||||
@ -410,12 +410,14 @@ class NeutralGradient(Optimizer):
|
||||
|
||||
if random.random() < 0.0001:
|
||||
print(f"Delta rms = {(delta**2).mean().item()}, shape = {delta.shape}")
|
||||
max_abs = delta.abs().max()
|
||||
if max_abs > 1.0:
|
||||
print("Max_abs_change = ", max_abs)
|
||||
#max_abs = delta.abs().max()
|
||||
#if max_abs > 1.0:
|
||||
# print("Max_abs_change = ", max_abs)
|
||||
p.add_(delta)
|
||||
if step % 10 == 0:
|
||||
p.clamp_(min=-param_max, max=param_max)
|
||||
|
||||
|
||||
state["step"] += 1
|
||||
|
||||
return loss
|
||||
@ -551,6 +553,10 @@ class NeutralGradient(Optimizer):
|
||||
param_rel_eps*param_rel_eps * params_sq.mean())
|
||||
|
||||
|
||||
params_sq_norm = params_sq
|
||||
if param_pow != 1.0:
|
||||
params_sq_partnorm = params_sq
|
||||
|
||||
for _ in range(3):
|
||||
# Iterate 3 times, this should be enough to converge.
|
||||
for dim in range(p.ndim):
|
||||
@ -561,8 +567,10 @@ class NeutralGradient(Optimizer):
|
||||
other_dims = [ i for i in range(p.ndim) if i != dim ]
|
||||
# Compute diagonal variance along this dimension
|
||||
# (this is after normalizing previous dimensions' variance)
|
||||
this_var = params_sq.mean(dim=other_dims, keepdim=True)
|
||||
params_sq = params_sq / this_var
|
||||
this_var = params_sq_norm.mean(dim=other_dims, keepdim=True)
|
||||
params_sq_norm = params_sq_norm / this_var
|
||||
if param_pow != 1.0:
|
||||
params_sq_partnorm = params_sq_partnorm / (this_var ** param_pow)
|
||||
|
||||
this_var = this_var.reshape(size)
|
||||
this_scale = (this_var ** (param_pow * 0.5)).reshape(size)
|
||||
@ -576,6 +584,17 @@ class NeutralGradient(Optimizer):
|
||||
# diagonalized space.
|
||||
proj *= this_scale.unsqueeze(-1)
|
||||
|
||||
if param_pow != 1.0:
|
||||
# need to get the overall scale corret, as if we had param_pow == 1.0
|
||||
scale = (params_sq_partnorm.mean() ** 0.5)
|
||||
for dim in range(p.ndim):
|
||||
size = p.shape[dim]
|
||||
if size == 1:
|
||||
continue
|
||||
else:
|
||||
state[f"proj_{dim}"] *= scale
|
||||
break
|
||||
|
||||
# Reset the squared-gradient stats because we have changed the basis.
|
||||
state["exp_avg_sq"].fill_(0.0)
|
||||
|
||||
@ -789,13 +808,29 @@ class NeutralGradient(Optimizer):
|
||||
|
||||
return P
|
||||
|
||||
def reset_speedup(self):
|
||||
"""
|
||||
Reset the step_period so that we'll do a step each time, until the next
|
||||
time
|
||||
"""
|
||||
for group in self.param_groups:
|
||||
group["speedup_step"] = -group["speedup_first_step"]
|
||||
for p in group["params"]:
|
||||
state = self.state[p]
|
||||
state["step_period"] = 1
|
||||
|
||||
def _recalibrate_speedup(self, group: dict):
|
||||
param_prods = []
|
||||
speedup = group["lr_for_speedup"] / group["lr"]
|
||||
if speedup > 10.0:
|
||||
speedup = 10.0
|
||||
if speedup < 2.0:
|
||||
speedup = 2.0
|
||||
if speedup <= 1.0:
|
||||
for p in group["params"]:
|
||||
state = self.state[p]
|
||||
state["step_period"] = 1
|
||||
return
|
||||
|
||||
|
||||
_, beta2 = group["betas"]
|
||||
|
||||
for p in group["params"]:
|
||||
@ -819,13 +854,16 @@ class NeutralGradient(Optimizer):
|
||||
param_prods.append(prod)
|
||||
|
||||
param_prods = torch.stack(param_prods).to('cpu')
|
||||
# TEMP
|
||||
param_prods = param_prods.sqrt()
|
||||
avg_update_freq = 1.0 / speedup
|
||||
param_freqs = param_prods * (avg_update_freq / param_prods.mean())
|
||||
|
||||
# Don't delay more than 40 steps for any parameter. Note: with beta=0.1,
|
||||
# this would mean an average delay of about 500 steps before the gradient
|
||||
# makes it to the parameter.
|
||||
min_freq = 1.0 / 40
|
||||
# Don't delay more than 1 / (3*speedup) steps for any parameter, to
|
||||
# ensure that no parameter waits super-long for an update. Note: with
|
||||
# beta=0.1, this would mean an average delay of about 30*speedup steps
|
||||
# before the gradient makes it to the parameter.
|
||||
min_freq = 1.0 / (3 * speedup)
|
||||
|
||||
# get the mean after applying limits of [min_freq..1] for the frequencies of
|
||||
# update
|
||||
@ -836,15 +874,23 @@ class NeutralGradient(Optimizer):
|
||||
param_freqs = param_freqs.clamp(min=min_freq, max=1.0)
|
||||
param_periods = (1.0 / param_freqs).to(torch.int32) # .to(torch.int32) rounds down
|
||||
|
||||
actual_speedup = 1.0 / (1.0 / param_periods).mean()
|
||||
|
||||
param_prods = param_prods.tolist()
|
||||
param_freqs = param_freqs.tolist()
|
||||
param_periods = param_periods.tolist()
|
||||
|
||||
print("speedup = ", speedup)
|
||||
for i,p in enumerate(group["params"]):
|
||||
logging.info(f"NeutralGradient._recalibrate_speedup: speedup = {speedup:.2g}, actual_speedup = {actual_speedup:.2g}")
|
||||
print_info = True # random.random() < 0.1
|
||||
i = 0
|
||||
for p in group["params"]:
|
||||
if p.grad is None:
|
||||
continue
|
||||
print(f"Shape = {p.shape}, prod = {param_prods[i]}, freq = {param_freqs[i]}, period = {param_periods[i]}.")
|
||||
state = self.state[p]
|
||||
state["step_period"] = param_periods[i]
|
||||
if print_info:
|
||||
logging.info(f"Shape = {p.shape}, prod = {param_prods[i]}, freq = {param_freqs[i]}, period = {param_periods[i]}.")
|
||||
i += 1
|
||||
|
||||
# Do nothing more, as yet.
|
||||
|
||||
@ -1507,6 +1553,8 @@ def _test_eve_cain():
|
||||
avg_loss = 0.0
|
||||
for epoch in range(150):
|
||||
scheduler.step_epoch()
|
||||
if epoch == 100 and iter in [2,3]:
|
||||
optim.reset_speedup() # check it doesn't crash.
|
||||
|
||||
if epoch == 130:
|
||||
opts = diagnostics.TensorDiagnosticOptions(
|
||||
@ -1531,7 +1579,8 @@ def _test_eve_cain():
|
||||
#scale1b = '%.2e' % (m[0].bias_scale.exp().item())
|
||||
#scale2 = '%.2e' % (m[2].weight_scale.exp().item())
|
||||
#scale2b = '%.2e' % (m[2].bias_scale.exp().item())
|
||||
print(f"Iter {iter}, epoch {epoch}, batch {n}, avg_loss {avg_loss:.4g}, norms={norm1,norm1b,norm2,norm2b}") # scales={scale1,scale1b,scale2,scale2b}
|
||||
lr = scheduler.get_last_lr()[0]
|
||||
print(f"Iter {iter}, epoch {epoch}, batch {n}, avg_loss {avg_loss:.4g}, lr={lr:.4e}, norms={norm1,norm1b,norm2,norm2b}") # scales={scale1,scale1b,scale2,scale2b}
|
||||
loss.log().backward()
|
||||
optim.step()
|
||||
optim.zero_grad()
|
||||
@ -1562,6 +1611,7 @@ def _test_eve_cain():
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
logging.getLogger().setLevel(logging.INFO)
|
||||
torch.set_num_threads(1)
|
||||
torch.set_num_interop_threads(1)
|
||||
_test_eve_cain()
|
||||
|
@ -762,6 +762,16 @@ def train_one_epoch(
|
||||
# in the batch and there is no normalization to it so far.
|
||||
scaler.scale(loss).backward()
|
||||
scheduler.step_batch(params.batch_idx_train)
|
||||
|
||||
if params.batch_idx_train == params.model_warm_step:
|
||||
# we're about to start using the pruned loss, which brings new
|
||||
# modules into play, so reset the frequencies of update, to
|
||||
# avoid possible instability.
|
||||
try:
|
||||
optimizer.reset_speedup()
|
||||
logging.info("Reset speedup on optimizer")
|
||||
except:
|
||||
pass
|
||||
scaler.step(optimizer)
|
||||
scaler.update()
|
||||
optimizer.zero_grad()
|
||||
@ -916,7 +926,9 @@ def run(rank, world_size, args):
|
||||
logging.info("Using DDP")
|
||||
model = DDP(model, device_ids=[rank])
|
||||
|
||||
optimizer = NeutralGradient(model.parameters(), lr=params.initial_lr)
|
||||
optimizer = NeutralGradient(model.parameters(),
|
||||
lr=params.initial_lr,
|
||||
lr_for_speedup=params.initial_lr)
|
||||
|
||||
scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user