diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py
index 038469282..92509f4ec 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py
@@ -154,15 +154,17 @@ def get_parser():
     parser.add_argument(
         "--lr-begin-steps",
         type=float,
-        default=20000,
+        default=25000,
         help="Number of steps that affects how rapidly the learning rate initially decreases"
     )
 
     parser.add_argument(
         "--lr-end-epochs",
         type=float,
-        default=10,
-        help="Number of epochs that affects how rapidly the learning rate finally decreases"
+        default=-1,
+        help="""Number of epochs that affects how rapidly the learning rate finally decreases;
+        if -1, will be set the same as --num-epochs
+        """
     )
 
     parser.add_argument(
@@ -783,12 +785,14 @@ def run(rank, world_size, args):
                    lr=params.initial_lr)
 
     # The `epoch` variable in the lambda expression binds to the value below
-    # in `for epoch in range(params.start_epoch, params.num_epochs):`.
+    # in `for epoch in range(params.start_epoch, params.num_epochs):`.  But set it to 0
+    # here to avoid crash in constructor.
     epoch = 0
+    lr_end_epochs = params.lr_end_epochs if params.lr_end_epochs > 0 else params.num_epochs
     scheduler = torch.optim.lr_scheduler.LambdaLR(
         optimizer,
         lambda step: (((step + params.lr_begin_steps) / params.lr_begin_steps) ** -0.5 *
-                      math.exp(-epoch / params.lr_end_epochs)))
+                      ((epoch + lr_end_epochs) / lr_end_epochs) ** -2.0))
 
     if checkpoints and "optimizer" in checkpoints:
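
For context: the patch keeps the inverse-square-root decay over steps but replaces the per-epoch term math.exp(-epoch / lr_end_epochs) (exponential in the epoch) with ((epoch + lr_end_epochs) / lr_end_epochs) ** -2.0 (a power law), and ties the epoch constant to --num-epochs when --lr-end-epochs is -1. Below is a minimal standalone sketch comparing the two factors; the function names, the steps-per-epoch figure, and the 30-epoch value (standing in for params.num_epochs) are all hypothetical, chosen only to illustrate the shapes of the two schedules.

import math

def lr_factor_old(step, epoch, lr_begin_steps=20000.0, lr_end_epochs=10.0):
    # Pre-patch schedule: inverse-sqrt decay over steps times an
    # exponential decay over epochs (defaults mirror the old argparse values).
    return (((step + lr_begin_steps) / lr_begin_steps) ** -0.5
            * math.exp(-epoch / lr_end_epochs))

def lr_factor_new(step, epoch, lr_begin_steps=25000.0, lr_end_epochs=30.0):
    # Post-patch schedule: the same step term, but the epoch term becomes
    # a power-law decay; lr_end_epochs here plays the role of
    # params.num_epochs, which the patch substitutes when --lr-end-epochs
    # is -1.
    return (((step + lr_begin_steps) / lr_begin_steps) ** -0.5
            * ((epoch + lr_end_epochs) / lr_end_epochs) ** -2.0)

if __name__ == "__main__":
    steps_per_epoch = 5000  # hypothetical; depends on dataset and batch size
    for epoch in (0, 5, 10, 20, 30):
        step = epoch * steps_per_epoch
        print(f"epoch {epoch:2d}: old {lr_factor_old(step, epoch):.4f}  "
              f"new {lr_factor_new(step, epoch):.4f}")

Under these assumptions the power-law term shrinks much more gently at large epoch counts than the old exponential: at epoch 30 the new epoch factor is 2.0 ** -2.0 = 0.25, versus exp(-3) ≈ 0.05 for the old one, consistent with the patch's intent of slowing the final learning-rate decrease and scaling it with --num-epochs.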