Replace warmup with lr scheduler.

Fangjun Kuang 2021-08-16 00:00:53 +08:00
parent 0be42bef69
commit 02e409b6ce
2 changed files with 28 additions and 1080 deletions
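At a glance: the old Moam optimizer (Madam wrapped in a Noam-style warm-up, per its name) ramped the learning rate up per optimizer step before decaying it; the diff below replaces it with a plain Madam optimizer started at the fixed rate warm_step ** -0.5, with all decay handled per epoch by torch.optim.lr_scheduler.LambdaLR. A rough sketch of the two curves follows; the Moam formula is an assumption (the standard Noam ramp), since this diff only shows its removal:

    # Sketch only: compares the assumed old per-step Noam ramp with the new
    # fixed-start, per-epoch decay introduced by this commit.
    def assumed_moam_lr(step: int, factor=2.0, model_size=256, warm_step=30000) -> float:
        # Noam schedule (for step >= 1): linear ramp over warm_step steps,
        # then a step ** -0.5 decay.
        return factor * model_size ** (-0.5) * min(step ** (-0.5), step * warm_step ** (-1.5))

    def new_lr(epoch: int, warm_step=30000) -> float:
        # Fixed base rate chosen by create_madam, scaled by the new LambdaLR lambda.
        base = warm_step ** (-0.5)  # ~5.77e-3 for warm_step=30000
        return base * (1.0 if epoch < 3 else 0.75 ** (epoch - 2))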


@@ -17,8 +17,7 @@ import torch.nn as nn
 from conformer import Conformer
 from lhotse.utils import fix_random_seed
-# from transformer import Noam
-from madam_no_warmup import Moam
+from madam import Madam
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.nn.utils import clip_grad_norm_
 from torch.utils.tensorboard import SummaryWriter
@@ -37,6 +36,29 @@ from icefall.utils import (
 )
 
 
+def create_madam(
+    params,
+    model_size: int = 256,
+    factor: float = 2.0,
+    warm_step: int = 25000,
+    min_target_rms: float = 0.05,
+    limit_grad_factor: float = float("inf"),
+    l2_period: int = 1,
+):
+    initial_lr = warm_step ** (-0.5)
+    optimizer = Madam(
+        params,
+        lr=initial_lr,
+        betas=(0.9, 0.98),
+        eps=1e-9,
+        min_target_rms=min_target_rms,
+        limit_grad_factor=limit_grad_factor,
+        l2_period=l2_period,
+    )
+    return optimizer
+
+
 def get_parser():
     parser = argparse.ArgumentParser(
         formatter_class=argparse.ArgumentDefaultsHelpFormatter
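With the warm_step of 30000 that get_params() sets below, the fixed starting rate chosen by create_madam comes out just under 6e-3; a quick check of that arithmetic alone:

    warm_step = 30000
    initial_lr = warm_step ** (-0.5)   # 1 / sqrt(30000)
    print(f"{initial_lr:.6f}")         # 0.005774

Note that model_size and factor are still accepted by the helper, but in the body shown above only warm_step affects the rate.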
@@ -86,8 +108,6 @@ def get_params() -> AttributeDict:
         - lang_dir: It contains language related input files such as
           "lexicon.txt"
 
-        - lr: It specifies the initial learning rate
-
         - feature_dim: The model input dim. It has to match the one used
           in computing features.
@@ -155,6 +175,7 @@ def get_params() -> AttributeDict:
             "mmi_loss": False,
             "use_feat_batchnorm": False,
             "lr_factor": 2.0,
+            "warm_step": 30000,
         }
     )
@@ -697,14 +718,15 @@ def run(rank, world_size, args):
     if world_size > 1:
         model = DDP(model, device_ids=[rank])
 
-    optimizer = Moam(
+    optimizer = create_madam(
         model.parameters(),
         model_size=params.attention_dim,
         factor=params.lr_factor,
+        warm_step=params.warm_step,
     )
     scheduler = torch.optim.lr_scheduler.LambdaLR(
-        optimizer, lambda ep: 1.0 if ep < 3 else 0.7 ** (ep - 2)
+        optimizer, lambda ep: 1.0 if ep < 3 else 0.75 ** (ep - 2)
     )
 
     if checkpoints and checkpoints["optimizer"]:
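The LambdaLR multiplier is also retuned from 0.7 to 0.75: the rate stays flat for the first three epochs and then decays geometrically. A minimal sketch of the per-epoch factors, assuming the scheduler is stepped once per epoch as the lambda's ep argument suggests:

    for ep in range(6):
        factor = 1.0 if ep < 3 else 0.75 ** (ep - 2)
        print(ep, round(factor, 4))
    # -> 0: 1.0, 1: 1.0, 2: 1.0, 3: 0.75, 4: 0.5625, 5: 0.4219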
@@ -720,8 +742,6 @@ def run(rank, world_size, args):
     for epoch in range(params.start_epoch, params.num_epochs):
         train_dl.sampler.set_epoch(epoch)
 
-        # LR scheduler can hold multiple learning rates for multiple parameter groups;
-        # For now we report just the first LR which we assume concerns most of the parameters.
         cur_lr = scheduler.get_last_lr()[0]
         if tb_writer is not None:
             tb_writer.add_scalar(