From da50525ca5f1cf5bb655adcac0e8ad5231aa7b5c Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Sun, 10 Apr 2022 13:25:40 +0800
Subject: [PATCH] Make lrate rule more symmetric

---
 .../ASR/pruned_transducer_stateless2/train.py | 17 ++++++++---------
 1 file changed, 8 insertions(+), 9 deletions(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py
index a114dd8f1..a8aaa4dde 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py
@@ -148,22 +148,22 @@ def get_parser():
         "--initial-lr",
         type=float,
         default=0.003,
-        help="The initial learning rate",
+        help="The initial learning rate. This value should not need to be changed.",
     )
 
     parser.add_argument(
-        "--lr-begin-steps",
+        "--lr-steps",
         type=float,
         default=5000,
-        help="Number of steps that affects how rapidly the learning rate initially decreases"
+        help="""Number of steps that affects how rapidly the learning rate decreases.
+        We suggest not to change this."""
     )
 
     parser.add_argument(
         "--lr-epochs",
         type=float,
-        default=-1,
-        help="""Number of epochs for purposes of the learning-rate schedule;
-        if -1, will be set the same as --num-epochs
+        default=6,
+        help="""Number of epochs that affects how rapidly the learning rate decreases.
         """
     )
 
@@ -788,11 +788,10 @@ def run(rank, world_size, args):
     # in `for epoch in range(params.start_epoch, params.num_epochs):`.  Set it to 0
     # here to avoid crash in constructor.
     epoch = 0
-    lr_epochs = params.lr_epochs if params.lr_epochs > 0 else params.num_epochs
     scheduler = torch.optim.lr_scheduler.LambdaLR(
         optimizer,
-        lambda step: (((step**2 + params.lr_begin_steps**2) / params.lr_begin_steps**2) ** -0.25 *
-                      ((epoch + lr_epochs) / lr_epochs) ** -0.5))
+        lambda step: (((step**2 + params.lr_steps**2) / params.lr_steps**2) ** -0.25 *
+                      (((epoch**2 + params.lr_epochs**2) / params.lr_epochs**2) ** -0.25)))
 
     if checkpoints and "optimizer" in checkpoints:
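
Illustrative note (not part of the patch): the new rule applies the same (x**2 + c**2) / c**2 to the
power -0.25 shape to both the step count and the epoch count, which is what "more symmetric" refers to.
A minimal standalone sketch of the multiplier applied to --initial-lr, assuming the defaults from the
patch (lr_steps=5000, lr_epochs=6); the function name lr_factor is hypothetical and does not exist in
train.py:

    def lr_factor(step, epoch, lr_steps=5000, lr_epochs=6):
        """Factor multiplied onto --initial-lr by the LambdaLR scheduler."""
        step_term = ((step**2 + lr_steps**2) / lr_steps**2) ** -0.25
        epoch_term = ((epoch**2 + lr_epochs**2) / lr_epochs**2) ** -0.25
        return step_term * epoch_term

    # lr_factor(0, 0) == 1.0 (no decay at the start);
    # lr_factor(5000, 6) == 2**-0.5, i.e. about 0.707, since both terms equal 2**-0.25 at their "knee";
    # far past the knees, each term decays roughly like x**-0.5.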