mirror of https://github.com/k2-fsa/icefall.git, synced 2025-08-08 17:42:21 +00:00
Simplified optimizer, rework some things.
parent 0f5957394b
commit c3169222ae
@@ -27,7 +27,7 @@ class Eve(Optimizer):
     r"""
     Implements Eve algorithm. This is a modified version of AdamW with a special
     way of setting the weight-decay / shrinkage-factor, which is designed to make the
-    rms of the parameters approach a particular specified value (we suggest 0.1). This is
+    rms of the parameters approach a particular target_rms (default: 0.1). This is
     for use with networks with 'scaled' versions of modules (see scaling.py), which
     will be close to invariant to the absolute scale on the parameter matrix.

@@ -43,10 +43,13 @@ class Eve(Optimizer):
             running averages of gradient and its square (default: (0.9, 0.999))
         eps (float, optional): term added to the denominator to improve
             numerical stability (default: 1e-8)
-        weight_decay (float, optional): weight decay coefficient (default: 1e-2)
-        amsgrad (boolean, optional): whether to use the AMSGrad variant of this
-            algorithm from the paper `On the Convergence of Adam and Beyond`_
-            (default: False)
+        weight_decay (float, optional): weight decay coefficient (default: 3e-4;
+            this value means that the weight would decay significantly after
+            about 3k minibatches. It is not multiplied by the learning rate, but
+            is conditional on the RMS-value of the parameter being > target_rms).
+        target_rms (float, optional): target root-mean-square value of the
+            parameters; if they fall below this we will stop applying weight decay.

     .. _Adam\: A Method for Stochastic Optimization:
         https://arxiv.org/abs/1412.6980
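For intuition, the decay rule described in the docstring can be sketched in isolation (this snippet is not part of the commit; conditional_decay is a hypothetical helper): a fixed shrinkage is applied only while the parameter's RMS exceeds target_rms, so the RMS settles near the target and the decay then switches off.

    import torch

    def conditional_decay(p, weight_decay=3e-4, target_rms=0.1):
        rms = p.norm() / (p.numel() ** 0.5)  # root-mean-square of the elements
        if rms > target_rms:
            p.mul_(1 - weight_decay)  # shrink; note: not scaled by the learning rate

    p = torch.randn(256, 256) * 0.5  # starts at rms ~ 0.5
    for _ in range(20000):
        conditional_decay(p)
    print((p.norm() / (p.numel() ** 0.5)).item())  # settles just below 0.1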
@@ -57,7 +60,7 @@ class Eve(Optimizer):
     """

     def __init__(self, params, lr=1e-3, betas=(0.9, 0.98), eps=1e-8,
-                 target_rms=0.1):
+                 weight_decay=3e-4, target_rms=0.1):

         if not 0.0 <= lr:
             raise ValueError("Invalid learning rate: {}".format(lr))
@@ -67,9 +70,12 @@ class Eve(Optimizer):
             raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
         if not 0.0 <= betas[1] < 1.0:
             raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
+        if not 0 <= weight_decay <= 0.1:
+            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
         if not 0 < target_rms <= 10.0:
             raise ValueError("Invalid target_rms value: {}".format(target_rms))
         defaults = dict(lr=lr, betas=betas, eps=eps,
+                        weight_decay=weight_decay,
                         target_rms=target_rms)
         super(Eve, self).__init__(params, defaults)
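With the new signature the optimizer is constructed as below (a usage sketch assuming the Eve class from this file is in scope; the Linear module is just a stand-in). Values outside the ranges checked above, e.g. weight_decay=0.5, raise ValueError.

    import torch

    model = torch.nn.Linear(80, 512)  # stand-in for a real network
    optimizer = Eve(
        model.parameters(),
        lr=1e-3, betas=(0.9, 0.98), eps=1e-8,
        weight_decay=3e-4, target_rms=0.1,
    )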
@@ -94,6 +100,9 @@ class Eve(Optimizer):
                 if p.grad is None:
                     continue

+
+
+
                 # Perform optimization step
                 grad = p.grad
                 if grad.is_sparse:
@@ -124,46 +133,21 @@ class Eve(Optimizer):

                 step_size = group['lr'] / bias_correction1
                 target_rms = group['target_rms']
-                delta = exp_avg / denom
-
-                # we'll be doing: p += delta * step_size.
-                # In the normal case delta_rms (the rms value of the elements of
-                # delta) will be very close to 1.0, but we compute it here so
-                # that if we don't use a particular parameter, its value won't
-                # shrink to zero.
-                # delta_var is the expected change in the variance of the parameter
-                # values, i.e. of E[param_elem^2], due to this step.  It will
-                # be close to 1.
-
-                # Let us define:
-                #   delta_var_from_update = (delta**2).mean() * step_size * step_size
-
-                # Suppose we are going to shrink with a small value epsilon (not the
-                # same as the eps above!), i.e. param *= (1-epsilon).  Then
-                # if E[param_elem^2] == target_rms^2 (because we desire equilibrium when
-                # the RMS of the parameters equals target_rms), it follows that
-                # E[(param_elem*(1-epsilon))^2] == target_rms^2 * (1 - 2*epsilon + epsilon^2),
-                # which we can write as:
-                #   delta_var_from_shrinkage \simeq -2 * epsilon * target_rms^2.
-                # Setting delta_var_from_shrinkage = -delta_var_from_update
-                # because we want them to cancel,
-                #   delta_var_from_update = 2 * epsilon * target_rms^2, or:
-                #   epsilon = delta_var_from_update / (2 * target_rms^2)
-                #           = (delta**2).mean() * 0.5 * (step_size / target_rms)**2.
-                # Note: step_size is close to the learning rate.  For example, if
-                # lr = 1.0e-04 and target_rms == 0.1, then in the normal case where
-                # (delta**2).mean() == 1, we will have:
-                #   epsilon = 1.0 * 0.5 * (1.0e-04 / 0.1)**2 = 5.0e-07.
-                # Note that this is close to the "traditional" value used for weight
-                # decay.
-                #
-                # This is the weight-decay amount.
-                #
-                # Regarding the 1/(1-beta1) factor below: this is to compensate for the
-                # deltas on successive frames being correlated.  I have to figure out
-                # the justification.
-                weight_decay = (delta ** 2).mean() * (0.5 * (step_size / target_rms) ** 2 * (1.0 / (1.0 - beta1)))
-
-                p.mul_(1 - weight_decay)
-                p.add_(delta, alpha=-step_size)
+                weight_decay = group['weight_decay']
+
+                if p.numel() > 1:
+                    # avoid applying this weight-decay on "scaling factors" (which are scalar).
+                    is_above_target_rms = (p.norm() > (target_rms * (p.numel() ** 0.5)))
+                    p.mul_(1 - (weight_decay * is_above_target_rms))
+                p.addcdiv_(exp_avg, denom, value=-step_size)

         return loss

+# Note on avg-change per epoch:
+# suppose an epoch is 4k iters.
+# If avg-change, measured as rms(diff) / rms(params), equals 0.2 and rms(params) = 0.1,
+# then rms(diff) = 0.1 * 0.2 and var(diff) = (0.1 * 0.2)**2 = 0.0004.  So var(diff per minibatch)
+# = 0.0004 / 4000 = 1e-07, i.e. rms(diff per minibatch) = 3.16e-04.  So the LR would be 3e-04.
+#
+# .. 6e-05 is 1/5 of that...
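The arithmetic in the removed derivation and in the note above can be checked in a few lines (constants are taken straight from the comments; nothing here is part of the optimizer):

    lr, target_rms = 1.0e-4, 0.1
    epsilon = 1.0 * 0.5 * (lr / target_rms) ** 2  # with (delta**2).mean() == 1
    print(epsilon)  # 5e-07, the per-step shrinkage at equilibrium

    rms_params, rel_change, iters_per_epoch = 0.1, 0.2, 4000
    var_diff = (rms_params * rel_change) ** 2  # 0.0004
    print((var_diff / iters_per_epoch) ** 0.5)  # ~3.16e-04, hence LR ~ 3e-04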
@@ -152,17 +152,17 @@ def get_parser():
     )

     parser.add_argument(
-        "--lr-decay-steps",
+        "--lr-num-steps",
         type=float,
-        default=5000,
-        help="The number of steps before we decay (halve) the learning rate",
+        default=3000,
+        help="Number of steps before we start to significantly decay the learning rate",
     )

     parser.add_argument(
-        "--num-lr-decays",
-        type=int,
-        default=4,
-        help="The total number of times we decay (halve) the learning rate"
+        "--lr-power",
+        type=float,
+        default=0.5,
+        help="Power in LR-setting rule",
     )

     parser.add_argument(
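The renamed options parse like any other argparse flags; a minimal reproduction for reference (illustrative only, the real get_parser() defines many more arguments around these):

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--lr-num-steps", type=float, default=3000,
        help="Number of steps before we start to significantly decay the learning rate",
    )
    parser.add_argument(
        "--lr-power", type=float, default=0.5,
        help="Power in LR-setting rule",
    )
    args = parser.parse_args([])
    print(args.lr_num_steps, args.lr_power)  # 3000 0.5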
@@ -781,12 +781,10 @@ def run(rank, world_size, args):
     optimizer = Eve(
         model.parameters(),
         lr=params.initial_lr, betas=(0.9, 0.98),
-        eps=1e-9, target_rms=0.1)
-    scheduler = torch.optim.lr_scheduler.MultiStepLR(
+        eps=1e-9, weight_decay=3e-04, target_rms=0.1)
+    scheduler = torch.optim.lr_scheduler.LambdaLR(
         optimizer,
-        [ n * params.lr_decay_steps for n in range(1, params.num_lr_decays+1) ],
-        gamma=0.5)
-
+        lambda step: ((params.lr_num_steps / (step + params.lr_num_steps)) ** params.lr_power))

     if checkpoints and "optimizer" in checkpoints:
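The lambda implements the rule factor(step) = (lr_num_steps / (step + lr_num_steps)) ** lr_power, which is 1.0 at step 0 and falls off polynomially. A standalone sketch with the new defaults (SGD stands in for Eve here; the printed values are just the factor times the initial LR):

    import torch

    lr_num_steps, lr_power, initial_lr = 3000, 0.5, 1e-3

    def lr_factor(step):
        return (lr_num_steps / (step + lr_num_steps)) ** lr_power

    opt = torch.optim.SGD([torch.zeros(1, requires_grad=True)], lr=initial_lr)
    scheduler = torch.optim.lr_scheduler.LambdaLR(opt, lr_factor)  # as in the diff

    for step in (0, 3000, 9000, 27000):
        print(step, initial_lr * lr_factor(step))
    # 0 -> 1.0e-03, 3000 -> ~7.1e-04, 9000 -> 5.0e-04, 27000 -> ~3.2e-04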