mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 10:02:22 +00:00
Make lrate rule more symmetric
This commit is contained in:
parent
4d41ee0caa
commit
da50525ca5
@ -148,22 +148,22 @@ def get_parser():
|
||||
"--initial-lr",
|
||||
type=float,
|
||||
default=0.003,
|
||||
help="The initial learning rate",
|
||||
help="The initial learning rate. This value should not need to be changed.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--lr-begin-steps",
|
||||
"--lr-steps",
|
||||
type=float,
|
||||
default=5000,
|
||||
help="Number of steps that affects how rapidly the learning rate initially decreases"
|
||||
help="""Number of steps that affects how rapidly the learning rate decreases.
|
||||
We suggest not to change this."""
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--lr-epochs",
|
||||
type=float,
|
||||
default=-1,
|
||||
help="""Number of epochs for purposes of the learning-rate schedule;
|
||||
if -1, will be set the same as --num-epochs
|
||||
default=6,
|
||||
help="""Number of epochs that affects how rapidly the learning rate decreases.
|
||||
"""
|
||||
)
|
||||
|
||||
@ -788,11 +788,10 @@ def run(rank, world_size, args):
|
||||
# in `for epoch in range(params.start_epoch, params.num_epochs):`. Set it to 0
|
||||
# here to avoid crash in constructor.
|
||||
epoch = 0
|
||||
lr_epochs = params.lr_epochs if params.lr_epochs > 0 else params.num_epochs
|
||||
scheduler = torch.optim.lr_scheduler.LambdaLR(
|
||||
optimizer,
|
||||
lambda step: (((step**2 + params.lr_begin_steps**2) / params.lr_begin_steps**2) ** -0.25 *
|
||||
((epoch + lr_epochs) / lr_epochs) ** -0.5))
|
||||
lambda step: (((step**2 + params.lr_steps**2) / params.lr_steps**2) ** -0.25 *
|
||||
(((epoch**2 + params.lr_epochs**2) / params.lr_epochs**2) ** -0.25))
|
||||
|
||||
|
||||
if checkpoints and "optimizer" in checkpoints:
|
||||
|
Loading…
x
Reference in New Issue
Block a user