Refactor optimizer (#1837)

* Print indexes of largest grad
This commit is contained in:
Han Zhu 2024-12-30 15:30:02 +08:00 committed by GitHub
parent 57e9f2a8db
commit 48088cb807
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -121,6 +121,139 @@ class BatchedOptimizer(Optimizer):
p.copy_(stacked_params[i])
def basic_step(group, p, state, grad):
# computes basic Adam update using beta2 (dividing by gradient stddev) only. no momentum yet.
lr = group["lr"]
if p.numel() == p.shape[0]:
lr = lr * group["scalar_lr_scale"]
beta2 = group["betas"][1]
eps = group["eps"]
# p shape: (batch_size,) or (batch_size, 1, [1,..])
try:
exp_avg_sq = state[
"exp_avg_sq"
] # shape: (batch_size,) or (batch_size, 1, [1,..])
except KeyError:
exp_avg_sq = torch.zeros(*p.shape, device=p.device, dtype=torch.float)
state["exp_avg_sq"] = exp_avg_sq
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
# bias_correction2 is like in Adam.
# slower update at the start will help stability anyway.
bias_correction2 = 1 - beta2 ** (state["step"] + 1)
if bias_correction2 < 0.99:
# note: not in-place.
exp_avg_sq = exp_avg_sq * (1.0 / bias_correction2)
denom = exp_avg_sq.sqrt().add_(eps)
return -lr * grad / denom
def scaling_step(group, p, state, grad):
delta = basic_step(group, p, state, grad)
if p.numel() == p.shape[0]:
return delta # there is no scaling for scalar parameters. (p.shape[0] is the batch of parameters.)
step = state["step"]
size_update_period = group["size_update_period"]
try:
param_rms = state["param_rms"]
scale_grads = state["scale_grads"]
scale_exp_avg_sq = state["scale_exp_avg_sq"]
except KeyError:
# we know p.ndim > 1 because we'd have returned above if not, so don't worry
# about the speial case of dim=[] that pytorch treats inconsistently.
param_rms = (p**2).mean(dim=list(range(1, p.ndim)), keepdim=True).sqrt()
param_rms = param_rms.to(torch.float)
scale_exp_avg_sq = torch.zeros_like(param_rms)
scale_grads = torch.zeros(
size_update_period, *param_rms.shape, dtype=torch.float, device=p.device
)
state["param_rms"] = param_rms
state["scale_grads"] = scale_grads
state["scale_exp_avg_sq"] = scale_exp_avg_sq
# on every step, update the gradient w.r.t. the scale of the parameter, we
# store these as a batch and periodically update the size (for speed only, to
# avoid too many operations).
scale_grads[step % size_update_period] = (p * grad).sum(
dim=list(range(1, p.ndim)), keepdim=True
)
# periodically recompute the value of param_rms.
if step % size_update_period == size_update_period - 1:
param_rms.copy_((p**2).mean(dim=list(range(1, p.ndim)), keepdim=True).sqrt())
param_min_rms = group["param_min_rms"]
# scale the step size by param_rms. This is the most important "scaling" part of
# ScaledAdam
delta *= param_rms.clamp(min=param_min_rms)
if step % size_update_period == size_update_period - 1 and step > 0:
# This block updates the size of parameter by adding a step ("delta") value in
# the direction of either shrinking or growing it.
beta2 = group["betas"][1]
size_lr = group["lr"] * group["scalar_lr_scale"]
param_max_rms = group["param_max_rms"]
eps = group["eps"]
batch_size = p.shape[0]
# correct beta2 for the size update period: we will have
# faster decay at this level.
beta2_corr = beta2**size_update_period
scale_exp_avg_sq.mul_(beta2_corr).add_(
(scale_grads**2).mean(dim=0), # mean over dim `size_update_period`
alpha=1 - beta2_corr,
) # shape is (batch_size, 1, 1, ...)
# The 1st time we reach here is when size_step == 1.
size_step = (step + 1) // size_update_period
bias_correction2 = 1 - beta2_corr**size_step
denom = scale_exp_avg_sq.sqrt() + eps
scale_step = (
-size_lr * (bias_correction2**0.5) * scale_grads.sum(dim=0) / denom
)
is_too_small = param_rms < param_min_rms
# when the param gets too small, just don't shrink it any further.
scale_step.masked_fill_(is_too_small, 0.0)
# The following may help prevent instability: don't allow the scale step to be too large in
# either direction.
scale_step.clamp_(min=-0.1, max=0.1)
# and ensure the parameter rms after update never exceeds param_max_rms.
# We have to look at the trained model for parameters at or around the
# param_max_rms, because sometimes they can indicate a problem with the
# topology or settings.
scale_step = torch.minimum(scale_step, (param_max_rms - param_rms) / param_rms)
delta.add_(p * scale_step)
return delta
def momentum_step(group, p, state, grad):
delta = scaling_step(group, p, state, grad)
beta1 = group["betas"][0]
try:
stored_delta = state["delta"]
except KeyError:
stored_delta = torch.zeros(*p.shape, device=p.device, dtype=torch.float)
state["delta"] = stored_delta
stored_delta.mul_(beta1)
stored_delta.add_(delta, alpha=(1 - beta1))
# we don't bother doing the "bias correction" part of Adam for beta1 because this is just
# an edge effect that affects the first 10 or so batches; and the effect of not doing it
# is just to do a slower update for the first few batches, which will help stability.
return stored_delta
class ScaledAdam(BatchedOptimizer):
"""
Implements 'Scaled Adam', a variant of Adam where we scale each parameter's update
@ -352,58 +485,26 @@ class ScaledAdam(BatchedOptimizer):
raise RuntimeError(
"ScaledAdam optimizer does not support sparse gradients"
)
# State initialization
if len(state) == 0:
self._init_state(group, p, state)
self._step_one_batch(group, p, state, clipping_scale)
try:
cur_step = state["step"]
except KeyError:
state["step"] = 0
cur_step = 0
grad = (
p.grad if clipping_scale == 1.0 else p.grad.mul_(clipping_scale)
)
p += momentum_step(group, p.detach(), state, grad)
if p.numel() == p.shape[0]: # scalar parameter
scalar_max = group["scalar_max"]
p.clamp_(min=-scalar_max, max=scalar_max)
state["step"] = cur_step + 1
return loss
def _init_state(self, group: dict, p: Tensor, state: dict):
"""
Initializes state dict for parameter 'p'. Assumes that dim 0 of tensor p
is actually the batch dimension, corresponding to batched-together
parameters of a given shape.
Args:
group: Dict to look up configuration values.
p: The parameter that we are initializing the state for
state: Dict from string to whatever state we are initializing
"""
size_update_period = group["size_update_period"]
state["step"] = 0
kwargs = {"device": p.device, "dtype": p.dtype}
# 'delta' implements conventional momentum. There are
# several different kinds of update going on, so rather than
# compute "exp_avg" like in Adam, we store and decay a
# parameter-change "delta", which combines all forms of
# update. this is equivalent to how it's done in Adam,
# except for the first few steps.
state["delta"] = torch.zeros_like(p, memory_format=torch.preserve_format)
batch_size = p.shape[0]
numel = p.numel() // batch_size
if numel > 1:
# "param_rms" just periodically records the scalar root-mean-square value of
# the parameter tensor.
# it has a shape like (batch_size, 1, 1, 1, 1)
param_rms = (p**2).mean(dim=list(range(1, p.ndim)), keepdim=True).sqrt()
state["param_rms"] = param_rms
state["scale_exp_avg_sq"] = torch.zeros_like(param_rms)
state["scale_grads"] = torch.zeros(
size_update_period, *param_rms.shape, **kwargs
)
# exp_avg_sq is the weighted sum of scaled gradients. as in Adam.
state["exp_avg_sq"] = torch.zeros_like(p, memory_format=torch.preserve_format)
def _get_clipping_scale(
self, group: dict, tuples: List[Tuple[Tensor, dict, List[str]]]
) -> float:
@ -484,7 +585,7 @@ class ScaledAdam(BatchedOptimizer):
)
first_state["num_clipped"] = 0
quartiles = " ".join(["%.3e" % x for x in quartiles])
logging.warn(
logging.warning(
f"Clipping_scale={clipping_scale}, grad-norm quartiles {quartiles}, "
f"threshold={threshold:.3e}, percent-clipped={percent_clipped:.1f}"
)
@ -499,8 +600,8 @@ class ScaledAdam(BatchedOptimizer):
ans = 0.0
if ans < 1.0:
first_state["num_clipped"] += 1
if ans < 0.1:
logging.warn(
if ans < 0.5:
logging.warning(
f"Scaling gradients by {ans}, model_norm_threshold={model_norm_threshold}"
)
if self.show_dominant_parameters:
@ -508,6 +609,7 @@ class ScaledAdam(BatchedOptimizer):
self._show_gradient_dominating_parameter(
tuples, tot_sumsq, group["scalar_lr_scale"]
)
self._show_param_with_unusual_grad(tuples)
if ans == 0.0:
for (p, state, param_names) in tuples:
@ -515,6 +617,55 @@ class ScaledAdam(BatchedOptimizer):
return ans
def _show_param_with_unusual_grad(
self,
tuples: List[Tuple[Tensor, dict, List[str]]],
):
"""
Print information about parameter which has the largest ratio of grad-on-this-batch
divided by normal grad size.
tuples: a list of tuples of (param, state, param_names)
where param is a batched set of parameters,
with a .grad (1st dim is batch dim)
and state is the state-dict where optimization parameters are kept.
param_names is a List[str] while each str is name for a parameter
in batched set of parameters "param".
"""
largest_ratio = 0.0
largest_name = ""
# ratios_names is a list of 3-tuples: (grad_ratio, param_name, tensor)
ratios_names = []
for (p, state, batch_param_names) in tuples:
dims = list(range(1, p.ndim))
def mean(x):
# workaround for bad interface of torch's "mean" for when dims is the empty list.
if len(dims) > 0:
return x.mean(dim=dims)
else:
return x
grad_ratio = (
(mean(p.grad**2) / state["exp_avg_sq"].mean(dim=dims))
.sqrt()
.to("cpu")
)
ratios_names += zip(
grad_ratio.tolist(), batch_param_names, p.grad.unbind(dim=0)
)
ratios_names = sorted(ratios_names, reverse=True)
ratios_names = ratios_names[:10]
ratios_names = [
(ratio, name, largest_index(tensor))
for (ratio, name, tensor) in ratios_names
]
logging.warning(
f"Parameters with most larger-than-usual grads, with ratios, are: {ratios_names}"
)
def _show_gradient_dominating_parameter(
self,
tuples: List[Tuple[Tensor, dict, List[str]]],
@ -572,7 +723,7 @@ class ScaledAdam(BatchedOptimizer):
dominant_rms,
dominant_grad,
) = sorted_by_proportion[dominant_param_name]
logging.warn(
logging.warning(
f"Parameter dominating tot_sumsq {dominant_param_name}"
f" with proportion {dominant_proportion:.2f},"
f" where dominant_sumsq=(grad_sumsq*orig_rms_sq)"
@ -581,182 +732,11 @@ class ScaledAdam(BatchedOptimizer):
f" orig_rms_sq={(dominant_rms**2).item():.3e}"
)
def _step_one_batch(
self, group: dict, p: Tensor, state: dict, clipping_scale: float
):
"""
Do the step for one parameter, which is actually going to be a batch of
`real` parameters, with dim 0 as the batch dim.
Args:
group: dict to look up configuration values
p: parameter to update (actually multiple parameters stacked together
as a batch)
state: state-dict for p, to look up the optimizer state
"""
lr = group["lr"]
size_update_period = group["size_update_period"]
beta1 = group["betas"][0]
grad = p.grad
if clipping_scale != 1.0:
grad *= clipping_scale
step = state["step"]
delta = state["delta"]
delta.mul_(beta1)
batch_size = p.shape[0]
numel = p.numel() // batch_size
if numel > 1:
# Update the size/scale of p, and set param_rms
scale_grads = state["scale_grads"]
scale_grads[step % size_update_period] = (p * grad).sum(
dim=list(range(1, p.ndim)), keepdim=True
)
if step % size_update_period == size_update_period - 1:
param_rms = state["param_rms"] # shape: (batch_size, 1, 1, ..)
param_rms.copy_(
(p**2).mean(dim=list(range(1, p.ndim)), keepdim=True).sqrt()
)
if step > 0:
# self._size_update() learns the overall scale on the
# parameter, by shrinking or expanding it.
self._size_update(group, scale_grads, p, state)
if numel == 1:
# For parameters with 1 element we just use regular Adam.
# Updates delta.
self._step_scalar(group, p, state)
else:
self._step(group, p, state)
state["step"] = step + 1
def _size_update(
self, group: dict, scale_grads: Tensor, p: Tensor, state: dict
) -> None:
"""
Called only where p.numel() > 1, this updates the scale of the parameter.
If we imagine: p = underlying_param * scale.exp(), and we are doing
gradient descent on underlying param and on scale, this function does the update
on `scale`.
Args:
group: dict to look up configuration values
scale_grads: a tensor of shape (size_update_period, batch_size, 1, 1,...) containing
grads w.r.t. the scales.
p: The parameter to update
state: The state-dict of p
"""
param_rms = state["param_rms"]
beta1, beta2 = group["betas"]
size_lr = group["lr"] * group["scalar_lr_scale"]
param_min_rms = group["param_min_rms"]
param_max_rms = group["param_max_rms"]
eps = group["eps"]
step = state["step"]
batch_size = p.shape[0]
size_update_period = scale_grads.shape[0]
# correct beta2 for the size update period: we will have
# faster decay at this level.
beta2_corr = beta2**size_update_period
scale_exp_avg_sq = state["scale_exp_avg_sq"] # shape: (batch_size, 1, 1, ..)
scale_exp_avg_sq.mul_(beta2_corr).add_(
(scale_grads**2).mean(dim=0), # mean over dim `size_update_period`
alpha=1 - beta2_corr,
) # shape is (batch_size, 1, 1, ...)
# The 1st time we reach here is when size_step == 1.
size_step = (step + 1) // size_update_period
bias_correction2 = 1 - beta2_corr**size_step
# we don't bother with bias_correction1; this will help prevent divergence
# at the start of training.
denom = scale_exp_avg_sq.sqrt() + eps
scale_step = (
-size_lr * (bias_correction2**0.5) * scale_grads.sum(dim=0) / denom
)
is_too_small = param_rms < param_min_rms
# when the param gets too small, just don't shrink it any further.
scale_step.masked_fill_(is_too_small, 0.0)
# and ensure the parameter rms after update never exceeds param_max_rms.
# We have to look at the trained model for parameters at or around the
# param_max_rms, because sometimes they can indicate a problem with the
# topology or settings.
scale_step = torch.minimum(scale_step, (param_max_rms - param_rms) / param_rms)
delta = state["delta"]
# the factor of (1-beta1) relates to momentum.
delta.add_(p * scale_step, alpha=(1 - beta1))
def _step(self, group: dict, p: Tensor, state: dict):
"""
This function does the core update of self.step(), in the case where the members of
the batch have more than 1 element.
Args:
group: A dict which will be used to look up configuration values
p: The parameter to be updated
grad: The grad of p
state: The state-dict corresponding to parameter p
This function modifies p.
"""
grad = p.grad
lr = group["lr"]
beta1, beta2 = group["betas"]
eps = group["eps"]
param_min_rms = group["param_min_rms"]
step = state["step"]
exp_avg_sq = state["exp_avg_sq"]
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=(1 - beta2))
this_step = state["step"] - (state["zero_step"] if "zero_step" in state else 0)
bias_correction2 = 1 - beta2 ** (this_step + 1)
if bias_correction2 < 0.99:
# note: not in-place.
exp_avg_sq = exp_avg_sq * (1.0 / bias_correction2)
denom = exp_avg_sq.sqrt()
denom += eps
grad = grad / denom
alpha = -lr * (1 - beta1) * state["param_rms"].clamp(min=param_min_rms)
delta = state["delta"]
delta.add_(grad * alpha)
p.add_(delta)
def _step_scalar(self, group: dict, p: Tensor, state: dict):
"""
A simplified form of the core update for scalar tensors, where we cannot get a good
estimate of the parameter rms.
"""
beta1, beta2 = group["betas"]
scalar_max = group["scalar_max"]
eps = group["eps"]
lr = group["lr"] * group["scalar_lr_scale"]
grad = p.grad
exp_avg_sq = state["exp_avg_sq"] # shape: (batch_size,)
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
# bias_correction2 is like in Adam. Don't bother with bias_correction1;
# slower update at the start will help stability anyway.
bias_correction2 = 1 - beta2 ** (state["step"] + 1)
denom = (exp_avg_sq / bias_correction2).sqrt() + eps
delta = state["delta"]
delta.add_(grad / denom, alpha=-lr * (1 - beta1))
p.clamp_(min=-scalar_max, max=scalar_max)
p.add_(delta)
def largest_index(x: Tensor):
x = x.contiguous()
argmax = x.abs().argmax().item()
return [(argmax // x.stride(i)) % x.size(i) for i in range(x.ndim)]
class LRScheduler(object):
@ -807,7 +787,6 @@ class LRScheduler(object):
self.__dict__.update(state_dict)
self.base_lrs = base_lrs
def get_last_lr(self) -> List[float]:
"""Return last computed learning rate by current scheduler. Will be a list of float."""
return self._last_lr
@ -853,7 +832,7 @@ class LRScheduler(object):
def print_lr(self, is_verbose, group, lr):
"""Display the current learning rate."""
if is_verbose:
logging.warn(
logging.warning(
f"Epoch={self.epoch}, batch={self.batch}: adjusting learning rate"
f" of group {group} to {lr:.4e}."
)
@ -1184,7 +1163,7 @@ def _test_scaled_adam(hidden_dim: int):
if iter == 0:
optim = Eve(m.parameters(), lr=0.003)
elif iter == 1:
optim = ScaledAdam(m.parameters(), lr=0.03, clipping_scale=2.0)
optim = ScaledAdam(m.named_parameters(), lr=0.03, clipping_scale=2.0)
scheduler = Eden(optim, lr_batches=200, lr_epochs=5, verbose=False)
start = timeit.default_timer()