Implement gradient clipping.

Daniel Povey 2022-09-16 16:52:46 +08:00
parent 8f876b3f54
commit 8298333bd2


@@ -40,6 +40,14 @@ class ScaledAdam(Optimizer):
lr: The learning rate. We will typically use a learning rate schedule that starts
at 0.03 and decreases over time, i.e. much higher than other common
optimizers.
clipping_scale: (e.g. 2.0)
A scale for gradient-clipping: if specified, the normalized gradients
over the whole model will be clipped so that their 2-norm does not exceed
`clipping_scale` times the median 2-norm over the most recent period
of `clipping_update_period` minibatches. By "normalized gradients",
we mean the gradients after multiplying by the rms value of this parameter
tensor [for non-scalars]; this is appropriate because our update is scaled
by this quantity.
betas: beta1,beta2 are momentum constants for regular momentum, and moving sum-sq grad.
Must satisfy 0 < beta1 <= beta2 < 1.
size_lr_scale: A scaling factor on the learning rate, that we use to update the
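To make the clipping rule described for `clipping_scale` concrete, here is a minimal, self-contained sketch; it is not the optimizer's own code, and the toy tensors, the assumed median norm, and the inline rms computation are stand-ins for state the optimizer actually keeps:

    import torch

    # Toy illustration of the clipping rule described for `clipping_scale`.
    # Assumed values: in the optimizer, the median norm comes from the most
    # recent `clipping_update_period` minibatches, and param_rms from state.
    clipping_scale = 2.0
    median_norm = 1.5                       # assumed median of recent normalized norms
    threshold = clipping_scale * median_norm

    params = [torch.randn(10, 20), torch.randn(30)]
    grads = [torch.randn(10, 20), torch.randn(30)]

    tot_sumsq = torch.tensor(0.0)
    for p, g in zip(params, grads):
        if p.numel() == 1:
            tot_sumsq += (g ** 2).sum()                    # scalars: raw gradient
        else:
            param_rms = (p ** 2).mean().sqrt()             # rms value of this tensor
            tot_sumsq += (g ** 2).sum() * param_rms ** 2   # "normalized" gradient

    tot_norm = tot_sumsq.sqrt()
    # Factor <= 1.0 by which all gradients would be scaled before the update.
    ans = min(1.0, (threshold / (tot_norm + 1.0e-20)).item())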
@@ -57,12 +65,14 @@ class ScaledAdam(Optimizer):
model has any parameters with numel() == 1).
size_update_period: The periodicity, in steps, with which we update the size (scale)
of the parameter tensor. This is provided to save a little time
in the update.
clipping_update_period: if clipping_scale is specified, this is the period, in
minibatches, over which we collect recent gradient norms and re-estimate the
clipping threshold.
"""
def __init__(
self,
params,
lr=3e-02,
clipping_scale=None,
betas=(0.9, 0.98),
size_lr_scale=0.1,
eps=1.0e-08,
@@ -70,12 +80,14 @@ class ScaledAdam(Optimizer):
param_max_rms=2.0,
scalar_max=2.0,
size_update_period=4,
clipping_update_period=100,
):
defaults = dict(
lr=lr,
clipping_scale=clipping_scale,
betas=betas,
size_lr_scale=size_lr_scale,
eps=eps,
@@ -83,6 +95,7 @@ class ScaledAdam(Optimizer):
param_max_rms=param_max_rms,
scalar_max=scalar_max,
size_update_period=size_update_period,
clipping_update_period=clipping_update_period,
)
super(ScaledAdam, self).__init__(params, defaults)
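For reference, a minimal usage sketch with the new arguments; the `model` here is a placeholder, `lr=0.03` and `clipping_scale=2.0` follow the values used in the docstring and in the test at the bottom of this file, and `clipping_update_period=100` is simply the default made explicit:

    import torch

    model = torch.nn.Linear(256, 256)            # placeholder model
    optim = ScaledAdam(
        model.parameters(),
        lr=0.03,                     # typical starting learning rate (see docstring)
        clipping_scale=2.0,          # clip to 2x the median recent gradient norm
        clipping_update_period=100,  # re-estimate that median every 100 minibatches
    )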
@@ -106,7 +119,9 @@ class ScaledAdam(Optimizer):
batch = True
for group in self.param_groups:
for i, p in enumerate(group["params"]):
state = self.state[p]
# Perform optimization step
@@ -118,7 +133,11 @@ class ScaledAdam(Optimizer):
# State initialization
if len(state) == 0:
self._init_state(group, p, state)
if i == 0:
clipping_scale = self._get_clipping_scale(group, p, state)
self._step_one_batch(group, p, state, clipping_scale)
return loss
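Putting the fragments of `step()` above back together: one clipping factor is computed per parameter group, from the first parameter's state, and then reused for every tensor in that group. A condensed sketch of that flow (not a new method, just the pieces above shown in one place):

    # Condensed view of the loop in step() after this change.
    for group in self.param_groups:
        for i, p in enumerate(group["params"]):
            state = self.state[p]
            if len(state) == 0:
                self._init_state(group, p, state)
            if i == 0:
                # One gradient-norm / clipping computation per group...
                clipping_scale = self._get_clipping_scale(group, p, state)
            # ...applied to every parameter's update in that group.
            self._step_one_batch(group, p, state, clipping_scale)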
@@ -170,11 +189,79 @@ class ScaledAdam(Optimizer):
p, memory_format=torch.preserve_format
)
def _get_clipping_scale(self,
group: dict,
p: Tensor,
state: dict) -> float:
"""
Returns a scalar factor <= 1.0 that dictates gradient clipping, i.e. we will scale the gradients
by this amount before applying the rest of the update.
This function is only to be called for the first parameter in the group.
"""
clipping_scale = group["clipping_scale"]
step = state["step"]
if clipping_scale is None or step == 0:
# no clipping. return early on step == 0 because the other
# parameters' state won't have been initialized yet.
return 1.0
clipping_update_period = group["clipping_update_period"]
tot_sumsq = torch.tensor(0.0, device=p.device)
for p in group["params"]:
state = self.state[p]
grad = p.grad
if grad.is_sparse:
raise RuntimeError(
"ScaledAdam optimizer does not support sparse gradients"
)
if p.numel() == 1:
tot_sumsq += (grad**2).sum() # sum() to change shape [1] to []
else:
tot_sumsq += (grad**2).sum() * (state["param_rms"] ** 2)
tot_norm = tot_sumsq.sqrt()
if not "model_norms" in state:
state["model_norms"] = torch.zeros(clipping_update_period,
device=p.device)
state["model_norms"][step % clipping_update_period] = tot_norm
if step % clipping_update_period == 0:
# We don't reach here if step == 0 because we
# would have returned above.
sorted_norms = state["model_norms"].sort()[0].to('cpu')
quartiles = []
for n in range(0, 5):
index = min(clipping_update_period - 1,
(clipping_update_period // 4) * n)
quartiles.append(sorted_norms[index].item())
median = quartiles[2]
threshold = clipping_scale * median
state["model_norm_threshold"] = threshold
percent_clipped = (state["num_clipped"] * 100.0 / clipping_update_period
if "num_clipped" in state else 0.0)
state["num_clipped"] = 0
quartiles = [ '%.3e' % x for x in quartiles ]
logging.info(f"Clipping_scale={clipping_scale}, grad-norm quartiles {quartiles}, "
f"threshold={threshold:.3e}, percent-clipped={percent_clipped:.1f}")
if step < clipping_update_period:
return 1.0 # We have not yet estimated a norm to clip to.
else:
model_norm_threshold = state["model_norm_threshold"]
ans = min(1.0,
(model_norm_threshold / (tot_norm + 1.0e-20)).item())
if ans < 1.0:
state["num_clipped"] += 1
if ans < 0.1:
logging.warning(f"Scaling gradients by {ans}, model_norm_threshold={model_norm_threshold}")
return ans
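As a standalone illustration of the threshold update in `_get_clipping_scale`: given a buffer of recent normalized gradient norms, the quartiles are read off the sorted buffer and the threshold is `clipping_scale` times the median. The norms below are random placeholders for the ring buffer `state["model_norms"]` that the function actually maintains:

    import torch

    clipping_update_period = 100
    clipping_scale = 2.0
    # Placeholder for state["model_norms"]: one normalized grad norm per recent minibatch.
    model_norms = torch.rand(clipping_update_period) * 10.0

    sorted_norms = model_norms.sort()[0]
    quartiles = []
    for n in range(0, 5):
        index = min(clipping_update_period - 1, (clipping_update_period // 4) * n)
        quartiles.append(sorted_norms[index].item())

    median = quartiles[2]
    threshold = clipping_scale * median
    # Any minibatch whose normalized grad norm exceeds `threshold` has its
    # gradients scaled by threshold / norm, i.e. clipped back to the threshold.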
def _step_one_batch(self,
group: dict,
p: Tensor,
state: dict,
clipping_scale: float):
"""
Do the step for parameter p.
Args:
@@ -188,6 +275,8 @@ class ScaledAdam(Optimizer):
beta1 = group["betas"][0]
grad = p.grad
if clipping_scale != 1.0:
grad = grad * clipping_scale
step = state["step"]
delta = state["delta"]
@@ -692,7 +781,7 @@ def _test_scaled_adam(hidden_dim: int):
torch.randn(B, T, E, device=device, dtype=dtype) * output_magnitudes) for _ in range(20) ]
if iter == 0: optim = Eve(m.parameters(), lr=0.003)
elif iter == 1: optim = ScaledAdam(m.parameters(), lr=0.03, clipping_scale=2.0)
scheduler = Eden(optim, lr_batches=200, lr_epochs=5, verbose=False)