Simplified optimizer, reworked some things.

Daniel Povey 2022-04-05 13:23:02 +08:00
parent 0f5957394b
commit c3169222ae
2 changed files with 39 additions and 57 deletions

View File

@@ -27,7 +27,7 @@ class Eve(Optimizer):
     r"""
     Implements Eve algorithm. This is a modified version of AdamW with a special
     way of setting the weight-decay / shrinkage-factor, which is designed to make the
-    rms of the parameters approach a particular specified value (we suggest 0.1). This is
+    rms of the parameters approach a particular target_rms (default: 0.1). This is
     for use with networks with 'scaled' versions of modules (see scaling.py), which
     will be close to invariant to the absolute scale on the parameter matrix.
@@ -43,10 +43,13 @@ class Eve(Optimizer):
             running averages of gradient and its square (default: (0.9, 0.999))
         eps (float, optional): term added to the denominator to improve
             numerical stability (default: 1e-8)
-        weight_decay (float, optional): weight decay coefficient (default: 1e-2)
-        amsgrad (boolean, optional): whether to use the AMSGrad variant of this
-            algorithm from the paper `On the Convergence of Adam and Beyond`_
-            (default: False)
+        weight_decay (float, optional): weight decay coefficient (default: 3e-4;
+            this value means that the weight would decay significantly after
+            about 3k minibatches. Is not multiplied by learning rate, but
+            is conditional on RMS-value of parameter being > target_rms.
+        target_rms (float, optional): target root-mean-square value of
+            parameters, if they fall below this we will stop applying weight decay.

     .. _Adam\: A Method for Stochastic Optimization:
         https://arxiv.org/abs/1412.6980
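The new docstring describes a weight decay that is applied directly to the parameter values (not scaled by the learning rate) and only while a tensor's RMS is above target_rms. A minimal standalone sketch of that rule, using the fact that RMS(p) > target_rms is equivalent to p.norm() > target_rms * sqrt(p.numel()); the function name here is illustrative, not part of the patch:

```python
import torch


def conditional_weight_decay(p: torch.Tensor,
                             weight_decay: float = 3e-4,
                             target_rms: float = 0.1) -> None:
    """Shrink `p` in place, but only while its RMS exceeds target_rms.

    RMS(p) > target_rms  <=>  p.norm() > target_rms * sqrt(p.numel()).
    """
    if p.numel() <= 1:
        return  # leave scalar "scaling factors" alone
    if p.norm() > target_rms * (p.numel() ** 0.5):
        p.mul_(1 - weight_decay)  # note: not multiplied by the learning rate
```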
@@ -57,7 +60,7 @@ class Eve(Optimizer):
     """

     def __init__(self, params, lr=1e-3, betas=(0.9, 0.98), eps=1e-8,
-                 target_rms=0.1):
+                 weight_decay=3e-4, target_rms=0.1):

         if not 0.0 <= lr:
             raise ValueError("Invalid learning rate: {}".format(lr))
@@ -67,9 +70,12 @@ class Eve(Optimizer):
             raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
         if not 0.0 <= betas[1] < 1.0:
             raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
+        if not 0 <= weight_decay <= 0.1:
+            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
         if not 0 < target_rms <= 10.0:
             raise ValueError("Invalid target_rms value: {}".format(target_rms))
         defaults = dict(lr=lr, betas=betas, eps=eps,
+                        weight_decay=weight_decay,
                         target_rms=target_rms)
         super(Eve, self).__init__(params, defaults)
@@ -94,6 +100,9 @@ class Eve(Optimizer):
                 if p.grad is None:
                     continue

                 # Perform optimization step
                 grad = p.grad
                 if grad.is_sparse:
@@ -124,46 +133,21 @@ class Eve(Optimizer):
                 step_size = group['lr'] / bias_correction1
                 target_rms = group['target_rms']
+                weight_decay = group['weight_decay']
                 delta = exp_avg / denom
-                # we'll be doing: p += delta * step_size.
-                # In the normal case delta_rms (the rms value of the elements of
-                # delta) will be very close to 1.0, but we compute it here so
-                # that if we don't use a particular parameter, its value won't
-                # shrink to zero.
-                # delta_var is the expected change in the variance of the parameter
-                # values, i.e. of E[param_elem^2], due to this step.  It will
-                # be close to 1.
-                # Let us define:
-                #   delta_var_from_update = (delta**2).mean() * step_size * step_size
-                # Suppose we are going to shrinkage with a small value epsilon (not the
-                # same as the eps above!), i.e. param *= (1-epsilon).  Then
-                # if E[param_elem^2] == target_rms^2 (because we desire equilibrium when
-                # the RMS of the parameters equals target_rms), it follows that
-                #   E[(param_elem*(1-epsilon))^2] == target_rms^2 (1 - 2epsilon + epsilon^2),
-                # which we can put as:
-                #   delta_var_from_shrinkage \simeq -2 epsilon target_rms^2.
-                # Setting delta_var_from_shrinkage = -delta_var_from_update
-                # because we want them to cancel,
-                #   delta_var_from_update = 2 epsilon target_rms^2, or:
-                #   epsilon = delta_var_from_update / (2 * target_rms^2)
-                #           = (delta**2).mean() * 0.5 * (step_size / target_rms)**2.
-                # Note: step_size is close to the learning rate.  For an example, if
-                # lr = 1.0e-04 and target_rms == 0.1, then in the normal case where
-                # (delta**2).mean() == 1, we will have:
-                #   epsilon = 1.0 * 0.5 * (1.0e-04 / 0.1) = 1.0e-06.
-                # Note that this is close to the "traditional" value used for weight
-                # decay.
-                #
-                # this is the weight-decay amount...
-                #
-                # Regarding the 1/(1-beta1) factor below: this is to compensate for the
-                # deltas on successive frames being correlated.  I have to figure out
-                # the justification.
-                weight_decay = (delta ** 2).mean() * (0.5 * (step_size / target_rms) ** 2 * (1.0 / (1.0 - beta1)))
-                p.mul_(1 - weight_decay)
-                p.add_(delta, alpha=-step_size)
+                if p.numel() > 1:
+                    # avoid applying this weight-decay on "scaling factors" (which are scalar).
+                    is_above_target_rms = (p.norm() > (target_rms * (p.numel() ** 0.5)))
+                    p.mul_(1 - (weight_decay * is_above_target_rms))
+                p.addcdiv_(exp_avg, denom, value=-step_size)

         return loss
+
+
+# Note on avg-change per epoch:
+# suppose an epoch is 4k iters.
+# if avg-change as rms(diff) / rms(params) equals 0.2, and rms(params) = 0.1,
+# then rms(diff) = 0.1 * 0.2 = 0.02, var(diff) = (0.1 * 0.2)**2 = 0.0004.  So var(diff per minibatch)
+# = 0.0004 / 4000 = 1e-07, and rms(diff per minibatch) = 3.16e-04.  So LR would be 3e-04.
+#
+# .. 6e-05 is 1/5 of that...
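The arithmetic in the added end-of-file note can be checked directly. The figures below (4k iterations per epoch, rms(params) = 0.1, desired relative change of 0.2 per epoch) are the ones the comment assumes:

```python
# Re-derive the numbers from the "avg-change per epoch" note above.
iters_per_epoch = 4000
rms_params = 0.1            # target_rms
rel_change_per_epoch = 0.2  # desired rms(diff) / rms(params) over one epoch

rms_diff_epoch = rms_params * rel_change_per_epoch   # 0.02
var_diff_epoch = rms_diff_epoch ** 2                 # 4e-04
# Treating per-minibatch changes as independent, the per-epoch variance is
# the sum of the per-minibatch variances:
var_diff_step = var_diff_epoch / iters_per_epoch     # 1e-07
rms_diff_step = var_diff_step ** 0.5                 # ~3.16e-04

print(rms_diff_step)  # ~3.16e-04, hence the suggested LR of about 3e-04
```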

View File

@@ -152,17 +152,17 @@ def get_parser():
     )

     parser.add_argument(
-        "--lr-decay-steps",
+        "--lr-num-steps",
         type=float,
-        default=5000,
-        help="The number of steps before we decay (halve) the learning rate",
+        default=3000,
+        help="Number of steps before we start to significantly decay the learning rate",
     )

     parser.add_argument(
-        "--num-lr-decays",
-        type=int,
-        default=4,
-        help="The total number of times we decay (halve) the learning rate"
+        "--lr-power",
+        type=float,
+        default=0.5,
+        help="Power in LR-setting rule",
     )

     parser.add_argument(
@@ -781,12 +781,10 @@ def run(rank, world_size, args):
     optimizer = Eve(
         model.parameters(),
         lr=params.initial_lr, betas=(0.9, 0.98),
-        eps=1e-9, target_rms=0.1)
+        eps=1e-9, weight_decay=3e-04, target_rms=0.1)

-    scheduler = torch.optim.lr_scheduler.MultiStepLR(
+    scheduler = torch.optim.lr_scheduler.LambdaLR(
         optimizer,
-        [ n * params.lr_decay_steps for n in range(1, params.num_lr_decays+1) ],
-        gamma=0.5)
+        lambda step: ((params.lr_num_steps / (step + params.lr_num_steps)) ** params.lr_power))

     if checkpoints and "optimizer" in checkpoints:
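Reading the new schedule as factor(step) = (lr_num_steps / (step + lr_num_steps)) ** lr_power, the LR multiplier starts at 1.0 and decays roughly like 1/sqrt(step) once step >> lr_num_steps. A small self-contained sketch with the new defaults (3000 steps, power 0.5); AdamW stands in for Eve here only so the snippet runs on its own:

```python
import torch
from torch.optim.lr_scheduler import LambdaLR

# Toy stand-ins; the recipe uses the Eve optimizer and the
# --lr-num-steps / --lr-power arguments added above.
model = torch.nn.Linear(10, 10)
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-04)

lr_num_steps = 3000  # --lr-num-steps default
lr_power = 0.5       # --lr-power default

scheduler = LambdaLR(
    optimizer,
    lambda step: (lr_num_steps / (step + lr_num_steps)) ** lr_power)

# LR multiplier at a few points: 1.0 at step 0, then a slow power-law decay.
for step in (0, 3000, 9000, 27000):
    factor = (lr_num_steps / (step + lr_num_steps)) ** lr_power
    print(step, round(factor, 3))  # 0 -> 1.0, 3000 -> 0.707, 9000 -> 0.5, 27000 -> 0.316
```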