Mirror of https://github.com/k2-fsa/icefall.git
Simplified optimizer, rework some things..
This commit is contained in:
parent 0f5957394b
commit c3169222ae
@@ -27,7 +27,7 @@ class Eve(Optimizer):
     r"""
     Implements Eve algorithm. This is a modified version of AdamW with a special
     way of setting the weight-decay / shrinkage-factor, which is designed to make the
-    rms of the parameters approach a particular specified value (we suggest 0.1). This is
+    rms of the parameters approach a particular target_rms (default: 0.1). This is
     for use with networks with 'scaled' versions of modules (see scaling.py), which
     will be close to invariant to the absolute scale on the parameter matrix.
 
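For intuition about the target_rms mentioned above: the RMS the optimizer reasons about is simply the elementwise root-mean-square of a parameter tensor, i.e. p.norm() / sqrt(p.numel()). A minimal illustration (the tensor shape and scale below are made up for the example):

    import torch

    target_rms = 0.1                           # suggested/default value above
    p = torch.randn(512, 512) * 0.1            # a hypothetical 'scaled' weight matrix
    param_rms = p.norm() / (p.numel() ** 0.5)  # root-mean-square of the elements
    print(float(param_rms))                    # roughly 0.1 for this initialization
    # The weight decay introduced in this commit is gated on how param_rms
    # compares to target_rms (see the step() hunk further down), rather than
    # always shrinking parameters towards zero.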
@@ -43,10 +43,13 @@ class Eve(Optimizer):
             running averages of gradient and its square (default: (0.9, 0.999))
         eps (float, optional): term added to the denominator to improve
             numerical stability (default: 1e-8)
-        weight_decay (float, optional): weight decay coefficient (default: 1e-2)
-        amsgrad (boolean, optional): whether to use the AMSGrad variant of this
-            algorithm from the paper `On the Convergence of Adam and Beyond`_
-            (default: False)
+        weight_decay (float, optional): weight decay coefficient (default: 3e-4;
+            this value means that the weight would decay significantly after
+            about 3k minibatches. Is not multiplied by learning rate, but
+            is conditional on RMS-value of parameter being > target_rms.
+        target_rms (float, optional): target root-mean-square value of
+            parameters, if they fall below this we will stop applying weight decay.
+
 
     .. _Adam\: A Method for Stochastic Optimization:
         https://arxiv.org/abs/1412.6980
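A rough sanity check of the "decay significantly after about 3k minibatches" remark in the new docstring: since the decay factor is applied directly (it is not multiplied by the learning rate), repeatedly applying the default 3e-4 shrinks a weight noticeably over a few thousand steps, ignoring the RMS gate:

    import math

    weight_decay = 3e-4          # new default introduced in this commit
    steps = 3000
    print((1 - weight_decay) ** steps)      # ~0.41 of the original magnitude left
    print(math.exp(-weight_decay * steps))  # ~0.41, continuous approximation exp(-0.9)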
@@ -57,7 +60,7 @@ class Eve(Optimizer):
     """
 
     def __init__(self, params, lr=1e-3, betas=(0.9, 0.98), eps=1e-8,
-                 target_rms=0.1):
+                 weight_decay=3e-4, target_rms=0.1):
 
         if not 0.0 <= lr:
             raise ValueError("Invalid learning rate: {}".format(lr))
@@ -67,9 +70,12 @@ class Eve(Optimizer):
             raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
         if not 0.0 <= betas[1] < 1.0:
             raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
+        if not 0 <= weight_decay <= 0.1:
+            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
         if not 0 < target_rms <= 10.0:
             raise ValueError("Invalid target_rms value: {}".format(target_rms))
         defaults = dict(lr=lr, betas=betas, eps=eps,
+                        weight_decay=weight_decay,
                         target_rms=target_rms)
         super(Eve, self).__init__(params, defaults)
 
@@ -94,6 +100,9 @@ class Eve(Optimizer):
                 if p.grad is None:
                     continue
+
+
+
 
                 # Perform optimization step
                 grad = p.grad
                 if grad.is_sparse:
@@ -124,46 +133,21 @@ class Eve(Optimizer):
 
                 step_size = group['lr'] / bias_correction1
                 target_rms = group['target_rms']
+                weight_decay = group['weight_decay']
                 delta = exp_avg / denom
 
-                # we'll be doing: p += delta * step_size.
-                # In the normal case delta_rms (the rms value of the elements of
-                # delta) will be very close to 1.0, but we compute it here so
-                # that if we don't use a particular parameter, its value won't
-                # shrink to zero.
-                # delta_var is the expected change in the variance of the parameter
-                # values, i.e. of E[param_elem^2], due to this step. It will
-                # be close to 1.
-
-                # Let us define:
-                # delta_var_from_update = (delta**2).mean() * step_size * step_size
-
-                # Suppose we are going to shrinkage with a small value epsilon (not the
-                # same as the eps above!), i.e. param *= (1-epsilon). Then
-                # if E[param_elem^2] == target_rms^2 (because we desire equilibrium when
-                # the RMS of the parameters equals target_rms), it follows that
-                # E[(param_elem*(1-epsilon))^2] == target_rms^2 (1 - 2epsilon + epsilon^2),
-                # which we can put as:
-                # delta_var_from_shrinkage \simeq -2 epsilon target_rms^2.
-                # Setting delta_var_from_shrinkage = -delta_var_from_update
-                # because we want them to cancel,
-                # delta_var_from_update = 2 epsilon target_rms^2, or:
-                # epsilon = delta_var_from_update / (2 * target_rms^2)
-                #         = (delta**2).mean() * 0.5 * (step_size / target_rms)**2.
-                # Note: step_size is close to the learning rate. For an example, if
-                # lr = 1.0e-04 and target_rms == 0.1, then in the normal case where
-                # (delta**2).mean() == 1, we will have:
-                # epsilon = 1.0 * 0.5 * (1.0e-04 / 0.1) = 1.0e-06.
-                # Note that this is close to the "traditional" value used for weight
-                # decay.
-                #
-                # this is the weight-decay amount...
-                #
-                # Regarding the 1/1-beta factor below: this is to compensate for the deltas on successive
-                # frames being correlated. I have to figure out the justification.
-                weight_decay = (delta ** 2).mean() * (0.5 * (step_size / target_rms) ** 2 * (1.0 / (1.0 - beta1)))
-
-                p.mul_(1 - weight_decay)
-                p.add_(delta, alpha=-step_size)
+                if p.numel() > 1:
+                    # avoid applying this weight-decay on "scaling factors" (which are scalar).
+                    is_below_target_rms = (p.norm() < (target_rms * (p.numel() ** 0.5)))
+                    p.mul_(1 - (weight_decay * is_below_target_rms))
+                p.addcdiv_(exp_avg, denom, value=-step_size)
 
         return loss
 
+# Note on avg-change per epoch..
+# suppose epoch is 4k iters.
+# if avg-change as rms(diff) / rms(params) equals 0.2, and rms(params) = 0.1,
+# then rm(diff) 0.1 * 0.2, var(diff) = (0.1 * 0.2)**2, = 0.0004. So var(diff per minibatch)
+# = 0.0004 / 4000 = 1e-07, rms(diff per minibatch) = 3.16e-04. So LR would be 3e-04.
+#
+# .. 6e-05 is 1/5 of that...
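To see the reworked update in isolation, below is a small self-contained sketch of the per-parameter step: decoupled weight decay gated on the parameter's RMS, followed by the usual Adam-style addcdiv. The names exp_avg, denom and step_size mirror the hunk; the helper function itself is invented for illustration, and its gate follows the docstring's wording (decay only while the RMS is above target_rms), whereas the added line in the hunk writes the comparison as is_below_target_rms with '<'.

    import torch

    def eve_like_update(p, exp_avg, denom, step_size,
                        weight_decay=3e-4, target_rms=0.1):
        # Decoupled weight decay: shrink the parameter directly instead of
        # adding a decay term to the gradient.
        if p.numel() > 1:
            # Skip scalar "scaling factors"; gate the decay on the parameter RMS.
            is_above_target_rms = p.norm() > (target_rms * (p.numel() ** 0.5))
            p.mul_(1 - weight_decay * float(is_above_target_rms))
        # Adam-style update: p -= step_size * exp_avg / denom
        p.addcdiv_(exp_avg, denom, value=-step_size)

    # Toy usage with made-up statistics:
    p = torch.randn(4, 4) * 0.2      # RMS around 0.2, i.e. above target_rms
    exp_avg = torch.randn(4, 4)      # stand-in for the first-moment estimate
    denom = torch.ones(4, 4)         # stand-in for sqrt(second moment) + eps
    eve_like_update(p, exp_avg, denom, step_size=3e-4)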
@@ -152,17 +152,17 @@ def get_parser():
     )
 
     parser.add_argument(
-        "--lr-decay-steps",
+        "--lr-num-steps",
         type=float,
-        default=5000,
-        help="The number of steps before we decay (halve) the learning rate",
+        default=3000,
+        help="Number of steps before we start to significantly decay the learning rate",
     )
 
     parser.add_argument(
-        "--num-lr-decays",
-        type=int,
-        default=4,
-        help="The total number of times we decay (halve) the learning rate"
+        "--lr-power",
+        type=float,
+        default=0.5,
+        help="Power in LR-setting rule",
     )
 
     parser.add_argument(
@@ -781,12 +781,10 @@ def run(rank, world_size, args):
     optimizer = Eve(
         model.parameters(),
         lr=params.initial_lr, betas=(0.9, 0.98),
-        eps=1e-9, target_rms=0.1)
-    scheduler = torch.optim.lr_scheduler.MultiStepLR(
+        eps=1e-9, weight_decay=3e-04, target_rms=0.1)
+    scheduler = torch.optim.lr_scheduler.LambdaLR(
         optimizer,
-        [ n * params.lr_decay_steps for n in range(1, params.num_lr_decays+1) ],
-        gamma=0.5)
-
+        lambda step: (params.lr_num_steps/(step + params.lr_num_steps) ** params.lr_power))
 
 
     if checkpoints and "optimizer" in checkpoints:
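The scheduler change swaps the step-wise halving of MultiStepLR for a smooth power-law decay via LambdaLR. Written with explicit parentheses, the rule the new --lr-num-steps / --lr-power flags describe is scale(step) = (lr_num_steps / (step + lr_num_steps)) ** lr_power; note that in the added line above, operator precedence applies the power to the denominator only. A standalone sketch of that rule with the new defaults (3000 steps, power 0.5); the Linear model and SGD optimizer below are placeholders:

    import torch

    lr_num_steps = 3000   # --lr-num-steps default
    lr_power = 0.5        # --lr-power default

    def lr_scale(step):
        # (lr_num_steps / (step + lr_num_steps)) ** lr_power
        return (lr_num_steps / (step + lr_num_steps)) ** lr_power

    for step in (0, 3000, 9000, 30000):
        print(step, round(lr_scale(step), 3))   # 1.0, 0.707, 0.5, 0.302

    # Wired up the same way as in the hunk (any optimizer works here):
    model = torch.nn.Linear(10, 10)
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_scale)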