Make lrate rule more symmetric

Daniel Povey 2022-04-10 13:25:40 +08:00
parent 4d41ee0caa
commit da50525ca5


@@ -148,22 +148,22 @@ def get_parser():
         "--initial-lr",
         type=float,
         default=0.003,
-        help="The initial learning rate",
+        help="The initial learning rate. This value should not need to be changed.",
     )
     parser.add_argument(
-        "--lr-begin-steps",
+        "--lr-steps",
         type=float,
         default=5000,
-        help="Number of steps that affects how rapidly the learning rate initially decreases"
+        help="""Number of steps that affects how rapidly the learning rate decreases.
+        We suggest not to change this."""
     )
     parser.add_argument(
         "--lr-epochs",
         type=float,
-        default=-1,
-        help="""Number of epochs for purposes of the learning-rate schedule;
-        if -1, will be set the same as --num-epochs
+        default=6,
+        help="""Number of epochs that affects how rapidly the learning rate decreases.
         """
     )
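
For reference, the three options above only set constants in a single multiplicative factor applied to --initial-lr. The following is a minimal standalone sketch of that factor, not code from this commit; the helper name lr_factor is hypothetical, while the defaults 0.003, 5000 and 6 are taken from the diff above:

# Sketch only: lr_factor is a hypothetical helper, not part of the commit.
def lr_factor(step: float, epoch: float,
              lr_steps: float = 5000.0, lr_epochs: float = 6.0) -> float:
    """Multiplier applied to --initial-lr.  After this commit both terms
    share the same ((x**2 + c**2) / c**2) ** -0.25 shape, one measured in
    optimizer steps and one in epochs."""
    step_term = ((step**2 + lr_steps**2) / lr_steps**2) ** -0.25
    epoch_term = ((epoch**2 + lr_epochs**2) / lr_epochs**2) ** -0.25
    return step_term * epoch_term

# The factor starts at 1.0, so training begins at --initial-lr (0.003 by
# default), and is roughly halved once step reaches a few times lr_steps.
print(lr_factor(step=0, epoch=0))      # 1.0
print(lr_factor(step=20000, epoch=2))  # noticeably smaller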
@@ -788,11 +788,10 @@ def run(rank, world_size, args):
     # in `for epoch in range(params.start_epoch, params.num_epochs):`. Set it to 0
     # here to avoid crash in constructor.
     epoch = 0
-    lr_epochs = params.lr_epochs if params.lr_epochs > 0 else params.num_epochs
     scheduler = torch.optim.lr_scheduler.LambdaLR(
         optimizer,
-        lambda step: (((step**2 + params.lr_begin_steps**2) / params.lr_begin_steps**2) ** -0.25 *
-                      ((epoch + lr_epochs) / lr_epochs) ** -0.5))
+        lambda step: (((step**2 + params.lr_steps**2) / params.lr_steps**2) ** -0.25 *
+                      (((epoch**2 + params.lr_epochs**2) / params.lr_epochs**2) ** -0.25)))
     if checkpoints and "optimizer" in checkpoints:
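
The "more symmetric" in the commit title refers to the epoch term now having the same functional form as the step term, instead of the earlier ((epoch + lr_epochs) / lr_epochs) ** -0.5 shape. Below is a rough comparison of the old and new epoch terms plus the scheduler wiring; it is an illustration under assumed names (old_epoch_term, new_epoch_term) and a toy optimizer, not code from the recipe:

import torch

def old_epoch_term(epoch: float, lr_epochs: float = 6.0) -> float:
    # Pre-commit rule: linear shift with a -0.5 exponent, so the decay
    # starts immediately at epoch 0.
    return ((epoch + lr_epochs) / lr_epochs) ** -0.5

def new_epoch_term(epoch: float, lr_epochs: float = 6.0) -> float:
    # Post-commit rule: same ((x**2 + c**2) / c**2) ** -0.25 shape as the
    # step term; nearly flat for the first few epochs, then ~ epoch ** -0.5.
    return ((epoch**2 + lr_epochs**2) / lr_epochs**2) ** -0.25

for e in [0, 1, 3, 6, 12, 24]:
    print(e, round(old_epoch_term(e), 3), round(new_epoch_term(e), 3))

# Wiring as in the diff: the lambda multiplies the optimizer's base lr
# (set from --initial-lr) by the step term and the captured epoch term.
model = torch.nn.Linear(4, 4)  # toy model, stands in for the recipe's model
optimizer = torch.optim.SGD(model.parameters(), lr=0.003)
epoch = 0
scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer,
    lambda step: ((step**2 + 5000.0**2) / 5000.0**2) ** -0.25 * new_epoch_term(epoch),
)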