From 58eb49821916ca87338c2cc69887ac5621c01cc2 Mon Sep 17 00:00:00 2001
From: Fangjun Kuang
Date: Mon, 16 Aug 2021 15:35:00 +0800
Subject: [PATCH] Set the initial learning rate directly.

---
 .../ASR/conformer_ctc_madam_no_warmup/train.py | 18 ++++--------------
 1 file changed, 4 insertions(+), 14 deletions(-)

diff --git a/egs/librispeech/ASR/conformer_ctc_madam_no_warmup/train.py b/egs/librispeech/ASR/conformer_ctc_madam_no_warmup/train.py
index fb74b5781..fdef575e4 100755
--- a/egs/librispeech/ASR/conformer_ctc_madam_no_warmup/train.py
+++ b/egs/librispeech/ASR/conformer_ctc_madam_no_warmup/train.py
@@ -38,17 +38,14 @@ from icefall.utils import (
 
 def create_madam(
     params,
-    model_size: int = 256,
-    factor: float = 2.0,
-    warm_step: int = 25000,
+    lr: float = 5e-4,
     min_target_rms: float = 0.05,
     limit_grad_factor: float = float("inf"),
     l2_period: int = 1,
 ):
-    initial_lr = warm_step ** (-0.5)
     optimizer = Madam(
         params,
-        lr=initial_lr,
+        lr=lr,
         betas=(0.9, 0.98),
         eps=1e-9,
         min_target_rms=min_target_rms,
@@ -166,7 +163,6 @@ def get_params() -> AttributeDict:
             "reduction": "sum",
             "use_double_scores": True,
             #
-            "accum_grad": 1,
             "att_rate": 0.7,
             "attention_dim": 512,
             "nhead": 8,
@@ -174,8 +170,7 @@ def get_params() -> AttributeDict:
             "is_espnet_structure": True,
             "mmi_loss": False,
             "use_feat_batchnorm": False,
-            "lr_factor": 2.0,
-            "warm_step": 30000,
+            "lr": 5e-4,
         }
     )
 
@@ -718,12 +713,7 @@ def run(rank, world_size, args):
     if world_size > 1:
         model = DDP(model, device_ids=[rank])
 
-    optimizer = create_madam(
-        model.parameters(),
-        model_size=params.attention_dim,
-        factor=params.lr_factor,
-        warm_step=params.warm_step,
-    )
+    optimizer = create_madam(model.parameters(), lr=params.lr)
     scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer, lambda ep: 1.0 if ep < 3 else 0.75 ** (ep - 2)
     )
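
Note: with this patch the Madam optimizer starts directly at lr = 5e-4 rather
than the Noam-style initial rate warm_step ** (-0.5) (roughly 5.8e-3 with the
previous warm_step = 30000), so all epoch-level decay now comes from the
LambdaLR schedule in the last hunk. Below is a minimal sketch of what that
schedule does, assuming stock PyTorch and using torch.optim.SGD plus a dummy
Linear module as stand-ins for Madam and the conformer model; only the lr
value and the lambda are taken from the patch.

import torch

model = torch.nn.Linear(4, 4)  # dummy parameters, standing in for the conformer
optimizer = torch.optim.SGD(model.parameters(), lr=5e-4)  # lr matches params.lr
scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer, lambda ep: 1.0 if ep < 3 else 0.75 ** (ep - 2)
)

for epoch in range(6):
    # epochs 0-2 train at 5e-4; from epoch 3 on the rate shrinks by 0.75 per epoch
    print(epoch, optimizer.param_groups[0]["lr"])
    optimizer.step()   # in real training this runs once per batch
    scheduler.step()   # advance the epoch-level schedule

This prints 5e-4 for epochs 0 through 2, then 3.75e-4, 2.8125e-4, and so on.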