Set new scheduler

Daniel Povey 2022-04-08 16:10:06 +08:00
parent 61486a0f76
commit 6ee32cf7af


@@ -28,15 +28,17 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
   --exp-dir pruned_transducer_stateless2/exp \
   --full-libri 1 \
   --max-duration 300 \
-  --initial-lr 0.002 \
-  --lr-decay-steps 10000 \
-  --num-lr-decays 4
+  --initial-lr 0.003 \
+  --lr-begin-steps 20000 \
+  --lr-end-steps 50000
 """
 import argparse
 import logging
+import math
 import warnings
 from pathlib import Path
 from shutil import copyfile
@@ -147,22 +149,22 @@ def get_parser():
     parser.add_argument(
         "--initial-lr",
         type=float,
-        default=0.002,
+        default=0.003,
         help="The initial learning rate",
     )

     parser.add_argument(
-        "--lr-num-steps",
+        "--lr-begin-steps",
         type=float,
-        default=3000,
-        help="Number of steps before we start to significantly decay the learning rate",
+        default=20000,
+        help="Number of steps that affects how rapidly the learning rate initially decreases"
     )

     parser.add_argument(
-        "--lr-power",
+        "--lr-end-steps",
         type=float,
-        default=0.75,
-        help="Power in LR-setting rule",
+        default=50000,
+        help="Number of steps that affects how rapidly the learning rate finally decreases"
     )

     parser.add_argument(
@@ -783,7 +785,8 @@ def run(rank, world_size, args):
         lr=params.initial_lr)
     scheduler = torch.optim.lr_scheduler.LambdaLR(
         optimizer,
-        lambda step: ((params.lr_num_steps/(step + params.lr_num_steps)) ** params.lr_power))
+        lambda step: (((step + params.lr_begin_steps) / params.lr_begin_steps) ** -0.5 *
+                      math.exp(-step / params.lr_end_steps)))

     if checkpoints and "optimizer" in checkpoints:
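The new rule replaces the old lr_num_steps / lr_power decay with a factor that combines an inverse-square-root decay controlled by --lr-begin-steps and an exponential decay controlled by --lr-end-steps; at step 0 the factor is 1.0, so training starts at --initial-lr. Below is a minimal, self-contained sketch of that schedule using the defaults from this commit. The linear model, SGD optimizer, and sampled step values are illustrative stand-ins only, not the training script's actual model or optimizer.

# Sketch of the scheduler introduced in this commit, with stand-in model/optimizer.
# The lambda's return value multiplies the base LR:
#   factor(step) = ((step + lr_begin_steps) / lr_begin_steps) ** -0.5
#                  * exp(-step / lr_end_steps)
import math

import torch

initial_lr = 0.003      # new --initial-lr default
lr_begin_steps = 20000  # new --lr-begin-steps default
lr_end_steps = 50000    # new --lr-end-steps default

model = torch.nn.Linear(10, 10)  # stand-in for the real transducer model
optimizer = torch.optim.SGD(model.parameters(), lr=initial_lr)  # stand-in optimizer

scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer,
    lambda step: (
        ((step + lr_begin_steps) / lr_begin_steps) ** -0.5
        * math.exp(-step / lr_end_steps)
    ),
)

# Print the effective learning rate at a few points along the schedule,
# evaluating the same factor the lambda computes.
for step in range(0, 100001, 20000):
    factor = ((step + lr_begin_steps) / lr_begin_steps) ** -0.5 * math.exp(
        -step / lr_end_steps
    )
    print(f"step {step:6d}: lr = {initial_lr * factor:.6f}")

In the training loop, scheduler.step() after each optimizer update advances the step count, and scheduler.get_last_lr() reports the learning rate currently applied.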