Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-12-11 06:55:27 +00:00)

Commit 8298333bd2 ("Implement gradient clipping."), parent 8f876b3f54.
@@ -40,6 +40,14 @@ class ScaledAdam(Optimizer):
     lr:  The learning rate.  We will typically use a learning rate schedule that starts
          at 0.03 and decreases over time, i.e. much higher than other common
          optimizers.
+    clipping_scale: (e.g. 2.0)
+         A scale for gradient-clipping: if specified, the normalized gradients
+         over the whole model will be clipped to have 2-norm equal to
+         `clipping_scale` times the median 2-norm over the most recent period
+         of `clipping_update_period` minibatches.  By "normalized gradients",
+         we mean after multiplying by the rms parameter value for this tensor
+         [for non-scalars]; this is appropriate because our update is scaled
+         by this quantity.
     betas: beta1, beta2 are momentum constants for regular momentum, and moving sum-sq grad.
          Must satisfy 0 < beta1 <= beta2 < 1.
     size_lr_scale: A scaling factor on the learning rate, that we use to update the
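To make the docstring above concrete, here is a minimal sketch (not part of the diff; the function name and the use of statistics.median are illustrative) of the rule it describes: the clipping threshold is clipping_scale times the median of recently observed normalized gradient norms, and gradients are scaled down only when the current norm exceeds that threshold.

    # Illustrative sketch only; clip_factor is not a function in the diffed file.
    import statistics

    def clip_factor(tot_norm: float, recent_norms: list, clipping_scale: float = 2.0) -> float:
        # Threshold: clipping_scale times the median of recently observed norms.
        threshold = clipping_scale * statistics.median(recent_norms)
        # Scale down only if the current norm exceeds the threshold.
        return min(1.0, threshold / (tot_norm + 1.0e-20))

    # Example: median 5.0 gives threshold 10.0 with clipping_scale=2.0,
    # so a minibatch with norm 15.0 has its gradients scaled by 10/15, about 0.67.
    print(clip_factor(15.0, [4.0, 5.0, 6.0]))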
@@ -57,12 +65,14 @@ class ScaledAdam(Optimizer):
          model has any parameters with numel() == 1).
     size_update_period: The periodicity, in steps, with which we update the size (scale)
          of the parameter tensor.  This is provided to save a little time
          in the update.
+    clipping_update_period: if clipping_scale is specified, this is the period, in
+         minibatches, with which we re-estimate the gradient-norm statistics used
+         to set the clipping threshold.
     """

     def __init__(
         self,
         params,
         lr=3e-02,
+        clipping_scale=None,
         betas=(0.9, 0.98),
         size_lr_scale=0.1,
         eps=1.0e-08,
@@ -70,12 +80,14 @@ class ScaledAdam(Optimizer):
         param_max_rms=2.0,
         scalar_max=2.0,
         size_update_period=4,
+        clipping_update_period=100,
     ):

         defaults = dict(
             lr=lr,
+            clipping_scale=clipping_scale,
             betas=betas,
             size_lr_scale=size_lr_scale,
             eps=eps,
@@ -83,6 +95,7 @@ class ScaledAdam(Optimizer):
             param_max_rms=param_max_rms,
             scalar_max=scalar_max,
             size_update_period=size_update_period,
+            clipping_update_period=clipping_update_period,
         )

         super(ScaledAdam, self).__init__(params, defaults)
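For context, a hedged usage sketch of the two new constructor arguments; the model and the import path are placeholders, not taken from this commit.

    # Usage sketch; the model and import path are placeholders.
    import torch
    from optim import ScaledAdam  # assumed module name

    model = torch.nn.Linear(256, 256)
    optimizer = ScaledAdam(
        model.parameters(),
        lr=0.03,
        clipping_scale=2.0,          # clip to 2x the recent median gradient norm
        clipping_update_period=100,  # re-estimate that median every 100 minibatches
    )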
@@ -106,7 +119,9 @@ class ScaledAdam(Optimizer):

         batch = True
         for group in self.param_groups:
-            for p in group["params"]:
+            for i, p in enumerate(group["params"]):
                 state = self.state[p]

                 # Perform optimization step
@@ -118,7 +133,11 @@ class ScaledAdam(Optimizer):
                 # State initialization
                 if len(state) == 0:
                     self._init_state(group, p, state)
-                self._step_one_batch(group, p, state)
+
+                if i == 0:
+                    clipping_scale = self._get_clipping_scale(group, p, state)
+
+                self._step_one_batch(group, p, state, clipping_scale)

         return loss
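The loop above computes the clipping factor once per parameter group (at i == 0), because it depends on the combined gradient norm of every tensor in the group, and then reuses it for each parameter. A schematic of that control flow, with illustrative callables standing in for the real methods:

    # Control-flow sketch only; the callables stand in for the real methods.
    def step_sketch(param_groups, get_clipping_scale, step_one_batch):
        for group in param_groups:
            clipping_scale = 1.0
            for i, p in enumerate(group["params"]):
                if i == 0:
                    # One factor per group, computed from all gradients in the group.
                    clipping_scale = get_clipping_scale(group, p)
                step_one_batch(group, p, clipping_scale)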
@@ -170,11 +189,79 @@ class ScaledAdam(Optimizer):
                 p, memory_format=torch.preserve_format
             )

+    def _get_clipping_scale(self,
+                            group: dict,
+                            p: Tensor,
+                            state: dict) -> float:
+        """
+        Returns a scalar factor <= 1.0 that dictates gradient clipping, i.e. we will
+        scale the gradients by this amount before applying the rest of the update.
+
+        This function is only to be called for the first parameter in the group.
+        """
+        clipping_scale = group["clipping_scale"]
+        step = state["step"]
+        if clipping_scale is None or step == 0:
+            # No clipping.  We return early on step == 0 because the other
+            # parameters' state won't have been initialized yet.
+            return 1.0
+        clipping_update_period = group["clipping_update_period"]
+
+        tot_sumsq = torch.tensor(0.0, device=p.device)
+        for p in group["params"]:
+            state = self.state[p]
+            grad = p.grad
+            if grad.is_sparse:
+                raise RuntimeError(
+                    "ScaledAdam optimizer does not support sparse gradients"
+                )
+            if p.numel() == 1:
+                tot_sumsq += (grad ** 2).sum()  # sum() to change shape [1] to []
+            else:
+                tot_sumsq += (grad ** 2).sum() * (state["param_rms"] ** 2)
+        tot_norm = tot_sumsq.sqrt()
+        if "model_norms" not in state:
+            state["model_norms"] = torch.zeros(clipping_update_period,
+                                               device=p.device)
+        state["model_norms"][step % clipping_update_period] = tot_norm
+
+        if step % clipping_update_period == 0:
+            # We don't reach here if step == 0 because we would have
+            # returned above.
+            sorted_norms = state["model_norms"].sort()[0].to('cpu')
+            quartiles = []
+            for n in range(0, 5):
+                index = min(clipping_update_period - 1,
+                            (clipping_update_period // 4) * n)
+                quartiles.append(sorted_norms[index].item())
+
+            median = quartiles[2]
+            threshold = clipping_scale * median
+            state["model_norm_threshold"] = threshold
+            percent_clipped = (state["num_clipped"] * 100.0 / clipping_update_period
+                               if "num_clipped" in state else 0.0)
+            state["num_clipped"] = 0
+            quartiles = ['%.3e' % x for x in quartiles]
+            logging.info(f"Clipping_scale={clipping_scale}, grad-norm quartiles {quartiles}, "
+                         f"threshold={threshold:.3e}, percent-clipped={percent_clipped:.1f}")
+
+        if step < clipping_update_period:
+            return 1.0  # We have not yet estimated a norm to clip to.
+        else:
+            model_norm_threshold = state["model_norm_threshold"]
+            ans = min(1.0,
+                      (model_norm_threshold / (tot_norm + 1.0e-20)).item())
+            if ans < 1.0:
+                state["num_clipped"] += 1
+            if ans < 0.1:
+                logging.warn(f"Scaling gradients by {ans}, "
+                             f"model_norm_threshold={model_norm_threshold}")
+            return ans
+
     def _step_one_batch(self,
                         group: dict,
                         p: Tensor,
-                        state: dict):
+                        state: dict,
+                        clipping_scale: float):
         """
         Do the step for parameter p.

         Args:
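The method above keeps a ring buffer of recent model-gradient norms and, every clipping_update_period steps, resets the threshold to clipping_scale times their median. A self-contained sketch of that bookkeeping; the class and its method names are illustrative, not part of the diffed file:

    # Self-contained sketch of the threshold bookkeeping; GradNormClipper is illustrative.
    import torch

    class GradNormClipper:
        def __init__(self, clipping_scale: float = 2.0, update_period: int = 100):
            self.clipping_scale = clipping_scale
            self.update_period = update_period
            self.norms = torch.zeros(update_period)  # ring buffer of recent norms
            self.threshold = None

        def factor(self, step: int, tot_norm: float) -> float:
            self.norms[step % self.update_period] = tot_norm
            if step > 0 and step % self.update_period == 0:
                # Re-estimate: clipping_scale times the median of the buffered norms.
                median = self.norms.sort()[0][self.update_period // 2].item()
                self.threshold = self.clipping_scale * median
            if self.threshold is None:
                return 1.0  # no estimate yet, so do not clip
            return min(1.0, self.threshold / (tot_norm + 1.0e-20))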
@@ -188,6 +275,8 @@ class ScaledAdam(Optimizer):
         beta1 = group["betas"][0]

         grad = p.grad
+        if clipping_scale != 1.0:
+            grad = grad * clipping_scale
         step = state["step"]
         delta = state["delta"]
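The two added lines are the only place the factor is consumed: the parameter's gradient is rescaled uniformly before the usual update. A worked example with hypothetical numbers:

    # Hypothetical numbers: threshold 10.0, current normalized grad norm 15.0.
    import torch

    factor = min(1.0, 10.0 / 15.0)   # about 0.667, as _get_clipping_scale would return
    grad = torch.tensor([0.3, -0.6, 1.2])
    if factor != 1.0:
        grad = grad * factor          # same rescale as in the diff above
    print(grad)                       # tensor([ 0.2000, -0.4000,  0.8000])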
@@ -692,7 +781,7 @@ def _test_scaled_adam(hidden_dim: int):
                    torch.randn(B, T, E, device=device, dtype=dtype) * output_magnitudes) for _ in range(20) ]

         if iter == 0: optim = Eve(m.parameters(), lr=0.003)
-        elif iter == 1: optim = ScaledAdam(m.parameters(), lr=0.03)
+        elif iter == 1: optim = ScaledAdam(m.parameters(), lr=0.03, clipping_scale=2.0)
         scheduler = Eden(optim, lr_batches=200, lr_epochs=5, verbose=False)
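Beyond this unit test, a hedged end-to-end sketch of training with clipping enabled; the model, data, and import path are placeholders, and only the ScaledAdam arguments come from this diff:

    # Training-loop sketch; model, data, and import path are placeholders.
    import torch
    from optim import ScaledAdam  # assumed module name, as in the test above

    model = torch.nn.Sequential(
        torch.nn.Linear(64, 64), torch.nn.ReLU(), torch.nn.Linear(64, 10)
    )
    optim = ScaledAdam(model.parameters(), lr=0.03, clipping_scale=2.0)

    for step in range(200):
        x = torch.randn(32, 64)
        y = torch.randint(0, 10, (32,))
        loss = torch.nn.functional.cross_entropy(model(x), y)
        optim.zero_grad()
        loss.backward()
        optim.step()  # clipping is applied inside step(), once per parameter group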