First version of the speedup code that I am running.
This commit is contained in:
parent a6bf32c3e0
commit e996f7f371
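The speedup scheme, in brief: each parameter tensor is assigned a step_period, and its relatively expensive preconditioned update runs only on every step_period-th optimizer step, with raw gradients summed into a cache in between. The periods are recalibrated every so often from per-tensor products E[grad . step], so the tensors whose updates matter least are updated least often. A minimal sketch of the period assignment (self-contained; assign_periods is an illustrative name, not icefall API):

import torch

def assign_periods(param_prods: torch.Tensor, speedup: float) -> torch.Tensor:
    """Choose per-tensor update periods whose average update frequency is
    1/speedup; clamp so that no tensor waits longer than 3*speedup steps."""
    freqs = param_prods * ((1.0 / speedup) / param_prods.mean())
    freqs = freqs.clamp(min=1.0 / (3 * speedup), max=1.0)
    return (1.0 / freqs).to(torch.int32)  # .to(torch.int32) rounds down

Here speedup = lr_for_speedup / group["lr"], capped at 10; the diff below lowers the default lr_for_speedup from 0.1 to 0.03.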
@@ -22,6 +22,7 @@ import random
 from torch import Tensor
 from torch.optim import Optimizer
 from icefall import diagnostics # only for testing code
+import logging
 
 class NeutralGradient(Optimizer):
     """
@@ -90,9 +91,9 @@ class NeutralGradient(Optimizer):
         estimate_period=2000,
         stats_steps=200,
         param_pow=1.0,
-        lr_for_speedup=0.1,
-        speedup_first_step=500, # TEMP. Make this 3500?  must be > warmup.
+        lr_for_speedup=0.03,
         speedup_recalibrate_period=100,
+        speedup_first_step=500,
     ):
 
         if not 0.0 <= lr:
@@ -117,6 +118,10 @@ class NeutralGradient(Optimizer):
             raise ValueError("Invalid estimate_period value: {}".format(estimate_period))
         if not 0 < stats_steps < estimate_period:
             raise ValueError("Invalid stats_steps value: {}".format(stats_steps))
+        if not 0 < speedup_first_step:
+            raise ValueError("Invalid speedup_first_step value: {}".format(speedup_first_step))
+        if not 0 < speedup_recalibrate_period:
+            raise ValueError("Invalid speedup_recalibrate_period value: {}".format(speedup_recalibrate_period))
 
 
         defaults = dict(
@@ -132,9 +137,9 @@ class NeutralGradient(Optimizer):
             stats_steps=stats_steps,
             param_pow=param_pow,
             lr_for_speedup=lr_for_speedup,
-            speedup_first_step=speedup_first_step,
             speedup_recalibrate_period=speedup_recalibrate_period,
-            group_step=0,
+            speedup_first_step=speedup_first_step,
+            speedup_step=-speedup_first_step,
         )
         super(NeutralGradient, self).__init__(params, defaults)
 
@@ -171,11 +176,11 @@ class NeutralGradient(Optimizer):
             speedup_first_step = group["speedup_first_step"]
             speedup_recalibrate_period = group["speedup_recalibrate_period"]
 
-            group_step = group["group_step"]
-            group["group_step"] = group_step + 1
+            speedup_step = group["speedup_step"]
+            group["speedup_step"] = speedup_step + 1
 
-            recalibrate = (group_step >= speedup_first_step and
-                           (group_step - speedup_first_step) % speedup_recalibrate_period == 0)
+            recalibrate = (speedup_step >= 0 and
+                           speedup_step % speedup_recalibrate_period == 0)
             if recalibrate:
                 # param_prods will be an array of products, one per parameter tensor, equal to:
                 #   E[grad . step],
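With the new speedup_step counter, initialized to -speedup_first_step and incremented on every step, recalibration first fires when the counter reaches 0, i.e. after speedup_first_step optimizer steps, and then every speedup_recalibrate_period steps thereafter. A quick check of the schedule with the default values above (500 and 100):

speedup_first_step, period = 500, 100
fires = [t for t in range(1000)
         if t - speedup_first_step >= 0 and (t - speedup_first_step) % period == 0]
print(fires)  # [500, 600, 700, 800, 900]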
@@ -270,31 +275,26 @@ class NeutralGradient(Optimizer):
 
                 step_period = state["step_period"]
                 n_cached_grads = state["n_cached_grads"]
-                if step_period > 1:
+                if step_period > 1 or n_cached_grads > 0:
                     cached_grad = state["cached_grad"]
-                    cached_grad += grad
                     n_cached_grads += 1
                     state["n_cached_grads"] = n_cached_grads
                     # the point of random_step_prob is to inject a little randomness
                     # into the timing of the updates, so they don't stay synchronized.
                     random_step_prob = 0.05 / step_period
-                    if n_cached_grads == step_period or random.random() < random_step_prob:
+                    if n_cached_grads >= step_period or random.random() < random_step_prob:
                         # We will continue to the code below, and do a step
-                        grad = cached_grad
+                        grad += cached_grad
+                        cached_grad.zero_()
+                        state["n_cached_grads"] = 0
                     else:
+                        cached_grad += grad
                         # Don't do a step this time.
                         continue
                 else:
                     # We are doing a step on every iteration.
                     n_cached_grads = 1 # will be used to scale gradients below.
 
-                # sqrt_n_grads will be used to scale down estimates of grad variances
-                # (in situations where the scale matters), to reflect the sizes of
-                # grads of individual
-                inv_sqrt_n_grads = n_cached_grads ** 0.5
-
-
 
                 delta = state["delta"]
                 step = state["step"]
 
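Note the behavioral fix in this hunk: the old code added grad into cached_grad up front and, as written, never cleared the cache or reset n_cached_grads after taking a step, so a tensor would fire once at n == step_period and afterwards only via random_step_prob, with the cache growing without bound. The new code adds the cache into grad when firing, zeroes it, and resets the counter. A toy, self-contained version of the fixed logic (accumulate_or_step is a hypothetical helper, not the icefall API):

import random
import torch

def accumulate_or_step(state: dict, grad: torch.Tensor):
    """Return the gradient to step with, or None to skip this iteration."""
    period = state["step_period"]
    if period > 1 or state["n_cached_grads"] > 0:
        state["n_cached_grads"] += 1
        # a little randomness so that different tensors' update schedules
        # don't stay synchronized; e.g. probability 0.0125 per step for period == 4.
        if (state["n_cached_grads"] >= period
                or random.random() < 0.05 / period):
            grad = grad + state["cached_grad"]
            state["cached_grad"].zero_()
            state["n_cached_grads"] = 0
            return grad
        state["cached_grad"] += grad
        return None  # no step this time
    return grad  # period == 1: step on every iteration

state = {"step_period": 4, "n_cached_grads": 0, "cached_grad": torch.zeros(3)}
for t in range(8):
    g = accumulate_or_step(state, torch.ones(3))
    # typically fires at t == 3 and t == 7, each time with the sum of 4 raw
    # gradients; occasionally earlier because of the random early step.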
@@ -410,12 +410,14 @@ class NeutralGradient(Optimizer):
 
                 if random.random() < 0.0001:
                     print(f"Delta rms = {(delta**2).mean().item()}, shape = {delta.shape}")
-                max_abs = delta.abs().max()
-                if max_abs > 1.0:
-                    print("Max_abs_change = ", max_abs)
+                #max_abs = delta.abs().max()
+                #if max_abs > 1.0:
+                #    print("Max_abs_change = ", max_abs)
                 p.add_(delta)
                 if step % 10 == 0:
                     p.clamp_(min=-param_max, max=param_max)
 
 
                 state["step"] += 1
 
         return loss
@@ -551,6 +553,10 @@ class NeutralGradient(Optimizer):
                              param_rel_eps*param_rel_eps * params_sq.mean())
 
 
+        params_sq_norm = params_sq
+        if param_pow != 1.0:
+            params_sq_partnorm = params_sq
+
         for _ in range(3):
             # Iterate 3 times, this should be enough to converge.
             for dim in range(p.ndim):
@@ -561,8 +567,10 @@ class NeutralGradient(Optimizer):
                 other_dims = [ i for i in range(p.ndim) if i != dim ]
                 # Compute diagonal variance along this dimension
                 # (this is after normalizing previous dimensions' variance)
-                this_var = params_sq.mean(dim=other_dims, keepdim=True)
-                params_sq = params_sq / this_var
+                this_var = params_sq_norm.mean(dim=other_dims, keepdim=True)
+                params_sq_norm = params_sq_norm / this_var
+                if param_pow != 1.0:
+                    params_sq_partnorm = params_sq_partnorm / (this_var ** param_pow)
 
                 this_var = this_var.reshape(size)
                 this_scale = (this_var ** (param_pow * 0.5)).reshape(size)
@@ -576,6 +584,17 @@ class NeutralGradient(Optimizer):
                 # diagonalized space.
                 proj *= this_scale.unsqueeze(-1)
 
+        if param_pow != 1.0:
+            # need to get the overall scale correct, as if we had param_pow == 1.0
+            scale = (params_sq_partnorm.mean() ** 0.5)
+            for dim in range(p.ndim):
+                size = p.shape[dim]
+                if size == 1:
+                    continue
+                else:
+                    state[f"proj_{dim}"] *= scale
+                    break
+
         # Reset the squared-gradient stats because we have changed the basis.
         state["exp_avg_sq"].fill_(0.0)
 
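For reference, the normalization being changed in the three hunks above: params_sq_norm is repeatedly divided by its per-dimension mean so that the per-dim variances (and hence the scales folded into each proj_{dim}) converge, while params_sq_partnorm tracks the same quantity with each division damped to this_var ** param_pow; its residual mean supplies the overall scale correction in the new if param_pow != 1.0 block. A rough standalone sketch of the per-dimension part, assuming p is a plain tensor and only the scales are wanted (per_dim_scales is an illustrative name):

import torch

def per_dim_scales(p: torch.Tensor, param_pow: float = 1.0) -> list:
    """Iteratively factor E[p**2] into per-dimension diagonal scales."""
    params_sq_norm = p * p
    scales = [torch.ones(1) for _ in range(p.ndim)]
    for _ in range(3):  # a few iterations should be enough to converge
        for dim in range(p.ndim):
            if p.shape[dim] == 1:
                continue
            other_dims = [i for i in range(p.ndim) if i != dim]
            this_var = params_sq_norm.mean(dim=other_dims, keepdim=True)
            params_sq_norm = params_sq_norm / this_var
            scales[dim] = scales[dim] * (this_var ** (param_pow * 0.5)).reshape(p.shape[dim])
    return scales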
@@ -789,13 +808,29 @@ class NeutralGradient(Optimizer):
 
         return P
 
+    def reset_speedup(self):
+        """
+        Reset the step_period so that we'll do a step each time, until the next
+        time the speedup is recalibrated.
+        """
+        for group in self.param_groups:
+            group["speedup_step"] = -group["speedup_first_step"]
+            for p in group["params"]:
+                state = self.state[p]
+                state["step_period"] = 1
+
     def _recalibrate_speedup(self, group: dict):
         param_prods = []
         speedup = group["lr_for_speedup"] / group["lr"]
         if speedup > 10.0:
             speedup = 10.0
-        if speedup < 2.0:
-            speedup = 2.0
+        if speedup <= 1.0:
+            for p in group["params"]:
+                state = self.state[p]
+                state["step_period"] = 1
+            return
+
 
         _, beta2 = group["betas"]
 
         for p in group["params"]:
@@ -819,13 +854,16 @@ class NeutralGradient(Optimizer):
             param_prods.append(prod)
 
         param_prods = torch.stack(param_prods).to('cpu')
+        # TEMP
+        param_prods = param_prods.sqrt()
         avg_update_freq = 1.0 / speedup
         param_freqs = param_prods * (avg_update_freq / param_prods.mean())
 
-        # Don't delay more than 40 steps for any parameter.  Note: with beta=0.1,
-        # this would mean an average delay of about 500 steps before the gradient
-        # makes it to the parameter.
-        min_freq = 1.0 / 40
+        # Don't delay more than 3*speedup steps for any parameter, to
+        # ensure that no parameter waits super-long for an update.  Note: with
+        # beta=0.1, this would mean an average delay of about 30*speedup steps
+        # before the gradient makes it to the parameter.
+        min_freq = 1.0 / (3 * speedup)
 
         # get the mean after applying limits of [min_freq..1] for the frequencies of
         # update
@@ -836,15 +874,23 @@ class NeutralGradient(Optimizer):
         param_freqs = param_freqs.clamp(min=min_freq, max=1.0)
         param_periods = (1.0 / param_freqs).to(torch.int32) # .to(torch.int32) rounds down
 
+        actual_speedup = 1.0 / (1.0 / param_periods).mean()
+
         param_prods = param_prods.tolist()
         param_freqs = param_freqs.tolist()
         param_periods = param_periods.tolist()
 
-        print("speedup = ", speedup)
-        for i,p in enumerate(group["params"]):
+        logging.info(f"NeutralGradient._recalibrate_speedup: speedup = {speedup:.2g}, actual_speedup = {actual_speedup:.2g}")
+        print_info = True # random.random() < 0.1
+        i = 0
+        for p in group["params"]:
             if p.grad is None:
                 continue
-            print(f"Shape = {p.shape}, prod = {param_prods[i]}, freq = {param_freqs[i]}, period = {param_periods[i]}.")
+            state = self.state[p]
+            state["step_period"] = param_periods[i]
+            if print_info:
+                logging.info(f"Shape = {p.shape}, prod = {param_prods[i]}, freq = {param_freqs[i]}, period = {param_periods[i]}.")
+            i += 1
 
         # Do nothing more, as yet.
 
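A worked example of the recalibration arithmetic above, with assumed numbers (three tensors, target speedup 4):

import torch

prods = torch.tensor([0.1, 0.3, 0.6])  # param_prods after .sqrt()
speedup = 4.0
freqs = prods * ((1.0 / speedup) / prods.mean())       # [0.075, 0.225, 0.450]
freqs = freqs.clamp(min=1.0 / (3 * speedup), max=1.0)  # min_freq = 1/12
periods = (1.0 / freqs).to(torch.int32)                # [12, 4, 2]
actual_speedup = 1.0 / (1.0 / periods).mean()          # ~3.6, close to the target 4

The largest-prod tensor is updated every 2 steps and the smallest every 12; actual_speedup falls a little short of the target because of the min_freq clamp and the round-down to integer periods.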
@@ -1507,6 +1553,8 @@ def _test_eve_cain():
         avg_loss = 0.0
         for epoch in range(150):
             scheduler.step_epoch()
+            if epoch == 100 and iter in [2,3]:
+                optim.reset_speedup() # check it doesn't crash.
 
             if epoch == 130:
                 opts = diagnostics.TensorDiagnosticOptions(
@@ -1531,7 +1579,8 @@ def _test_eve_cain():
                 #scale1b = '%.2e' % (m[0].bias_scale.exp().item())
                 #scale2 = '%.2e' % (m[2].weight_scale.exp().item())
                 #scale2b = '%.2e' % (m[2].bias_scale.exp().item())
-                print(f"Iter {iter}, epoch {epoch}, batch {n}, avg_loss {avg_loss:.4g}, norms={norm1,norm1b,norm2,norm2b}") # scales={scale1,scale1b,scale2,scale2b}
+                lr = scheduler.get_last_lr()[0]
+                print(f"Iter {iter}, epoch {epoch}, batch {n}, avg_loss {avg_loss:.4g}, lr={lr:.4e}, norms={norm1,norm1b,norm2,norm2b}") # scales={scale1,scale1b,scale2,scale2b}
             loss.log().backward()
             optim.step()
             optim.zero_grad()
@@ -1562,6 +1611,7 @@ def _test_eve_cain():
 
 
 if __name__ == "__main__":
+    logging.getLogger().setLevel(logging.INFO)
     torch.set_num_threads(1)
     torch.set_num_interop_threads(1)
     _test_eve_cain()
@@ -762,6 +762,16 @@ def train_one_epoch(
             # in the batch and there is no normalization to it so far.
             scaler.scale(loss).backward()
             scheduler.step_batch(params.batch_idx_train)
+
+            if params.batch_idx_train == params.model_warm_step:
+                # we're about to start using the pruned loss, which brings new
+                # modules into play, so reset the frequencies of update, to
+                # avoid possible instability.
+                try:
+                    optimizer.reset_speedup()
+                    logging.info("Reset speedup on optimizer")
+                except:
+                    pass
             scaler.step(optimizer)
             scaler.update()
             optimizer.zero_grad()
@@ -916,7 +926,9 @@ def run(rank, world_size, args):
         logging.info("Using DDP")
         model = DDP(model, device_ids=[rank])
 
-    optimizer = NeutralGradient(model.parameters(), lr=params.initial_lr)
+    optimizer = NeutralGradient(model.parameters(),
+                                lr=params.initial_lr,
+                                lr_for_speedup=params.initial_lr)
 
     scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)
 
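One consequence of the run() change above: since speedup = lr_for_speedup / group["lr"] and lr_for_speedup is now set to params.initial_lr, the speedup is about 1 at the start of training (so _recalibrate_speedup takes the speedup <= 1.0 branch and every tensor keeps step_period = 1), and the gradient-skipping only kicks in as the Eden schedule decays the learning rate. Roughly (initial_lr value assumed for illustration):

initial_lr = 3e-3  # hypothetical params.initial_lr
for lr in [3e-3, 1e-3, 3e-4]:
    speedup = min(initial_lr / lr, 10.0)
    print(f"lr={lr:g} speedup={speedup:g}")  # 1 (disabled), 3, 10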