Mirror of https://github.com/k2-fsa/icefall.git
Commit 4d41ee0caa: Implement 2o schedule
Parent: db72aee1f0
@@ -100,9 +100,6 @@ class Eve(Optimizer):
                 if p.grad is None:
                     continue
-
-
-
 
                 # Perform optimization step
                 grad = p.grad
                 if grad.is_sparse:
@@ -144,12 +141,3 @@ class Eve(Optimizer):
                 p.addcdiv_(exp_avg, denom, value=-step_size)
 
         return loss
-
-# Note on avg-change per epoch..
-# suppose epoch is 4k iters.
-# if avg-change as rms(diff) / rms(params) equals 0.2, and rms(params) = 0.1,
-# then rms(diff) = 0.1 * 0.2, var(diff) = (0.1 * 0.2)**2 = 0.0004. So var(diff per minibatch)
-# = 0.0004 / 4000 = 1e-07, rms(diff per minibatch) = 3.16e-04. So LR would be 3e-04.
-# Suggested lr_schedule?
-#
-# .. 6e-05 is 1/5 of that...
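The comment block removed above was a back-of-envelope calculation relating a target average parameter change per epoch to a per-minibatch learning rate. A minimal sketch of that arithmetic, using only the assumptions stated in the removed comment (roughly 4000 minibatches per epoch, a target rms(diff)/rms(params) of 0.2 per epoch, and rms(params) of 0.1):

    import math

    iters_per_epoch = 4000          # "suppose epoch is 4k iters"
    rel_change_per_epoch = 0.2      # target rms(diff) / rms(params) per epoch
    rms_params = 0.1

    rms_diff_epoch = rel_change_per_epoch * rms_params          # 0.02
    var_diff_epoch = rms_diff_epoch ** 2                         # 4e-04
    var_diff_minibatch = var_diff_epoch / iters_per_epoch        # 1e-07
    rms_diff_minibatch = math.sqrt(var_diff_minibatch)           # ~3.16e-04
    print(rms_diff_minibatch)  # the comment reads this off as an LR of roughly 3e-04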
@@ -154,15 +154,15 @@ def get_parser():
     parser.add_argument(
         "--lr-begin-steps",
         type=float,
-        default=25000,
+        default=5000,
         help="Number of steps that affects how rapidly the learning rate initially decreases"
     )
 
     parser.add_argument(
-        "--lr-end-epochs",
+        "--lr-epochs",
         type=float,
         default=-1,
-        help="""Number of epochs that affects how rapidly the learning rate finally decreases;
+        help="""Number of epochs for purposes of the learning-rate schedule;
         if -1, will be set the same as --num-epochs
         """
     )
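As a usage note, leaving --lr-epochs at its default of -1 makes the schedule fall back to --num-epochs; setting it explicitly decouples the decay horizon from how long training actually runs. An illustrative invocation (the script path and values are assumptions, not shown in this diff):

    python ./train.py --lr-begin-steps 5000 --lr-epochs 30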
@@ -784,15 +784,15 @@ def run(rank, world_size, args):
         model.parameters(),
         lr=params.initial_lr)
 
-    # The `epoch` variable in the lambda expression binds to the value below
-    # in `for epoch in range(params.start_epoch, params.num_epochs):`. But set it to 0
+    # The `epoch` variable in the lambda expression picks up the value below
+    # in `for epoch in range(params.start_epoch, params.num_epochs):`. Set it to 0
     # here to avoid crash in constructor.
     epoch = 0
-    lr_end_epochs = params.lr_end_epochs if params.lr_end_epochs > 0 else params.num_epochs
+    lr_epochs = params.lr_epochs if params.lr_epochs > 0 else params.num_epochs
     scheduler = torch.optim.lr_scheduler.LambdaLR(
         optimizer,
-        lambda step: (((step + params.lr_begin_steps) / params.lr_begin_steps) ** -0.5 *
-                      ((epoch + lr_end_epochs) / lr_end_epochs) ** -2.0))
+        lambda step: (((step**2 + params.lr_begin_steps**2) / params.lr_begin_steps**2) ** -0.25 *
+                      ((epoch + lr_epochs) / lr_epochs) ** -0.5))
 
 
     if checkpoints and "optimizer" in checkpoints:
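The net effect of the hunk above: the per-step factor changes from ((step + S)/S)**-0.5 to ((step**2 + S**2)/S**2)**-0.25, which decays the same way asymptotically (roughly step**-0.5 once step is much larger than S) but is flatter near step = 0, and the per-epoch factor softens from a -2.0 power to -0.5. Because the lambda looks up `epoch` at call time, reassigning `epoch` in the training loop changes the epoch factor without rebuilding the scheduler, which is what the comment about the variable binding refers to. A small standalone sketch of the new multiplier (the parameter values are illustrative, not taken from the recipe):

    def schedule_factor(step, epoch, lr_begin_steps=5000.0, lr_epochs=30.0):
        # Multiplier that LambdaLR applies to the base lr under the new schedule.
        step_factor = ((step**2 + lr_begin_steps**2) / lr_begin_steps**2) ** -0.25
        epoch_factor = ((epoch + lr_epochs) / lr_epochs) ** -0.5
        return step_factor * epoch_factor

    for step in (0, 5000, 20000, 80000):
        print(step, round(schedule_factor(step, epoch=0), 4))
    # 0 -> 1.0, 5000 -> ~0.84, 20000 -> ~0.49, 80000 -> ~0.25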