Replace warmup with lr scheduler.

Fangjun Kuang 2021-08-15 22:59:51 +08:00
parent 21292066ec
commit 0be42bef69
2 changed files with 1087 additions and 5 deletions

File diff suppressed because it is too large


@@ -18,7 +18,7 @@ from conformer import Conformer
 from lhotse.utils import fix_random_seed
 # from transformer import Noam
-from madam import Moam
+from madam_no_warmup import Moam
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.nn.utils import clip_grad_norm_
 from torch.utils.tensorboard import SummaryWriter
@@ -155,7 +155,6 @@ def get_params() -> AttributeDict:
             "mmi_loss": False,
             "use_feat_batchnorm": False,
             "lr_factor": 2.0,
-            "warm_step": 30000,
         }
     )
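Note: the dropped warm_step controlled the warmup phase of the Noam-style schedule that Moam appears to follow. Below is a minimal sketch of the standard Noam formula (Vaswani et al., 2017) to illustrate what the parameter did; it is an assumption that Moam implements exactly this formula.

# Sketch of the standard Noam learning-rate formula; shown only to illustrate
# the role of the removed warm_step=30000.  Whether Moam follows this formula
# exactly is an assumption.
def noam_lr(step: int, model_size: int, factor: float, warm_step: int) -> float:
    step = max(step, 1)  # avoid 0 ** -0.5 on the very first step
    # Linear warmup for the first `warm_step` steps, then ~step^-0.5 decay.
    return factor * model_size ** -0.5 * min(step ** -0.5, step * warm_step ** -1.5)

# E.g. with warm_step=30000 the LR peaks around step 30000 and decays
# afterwards.  The commit removes this per-step warmup and hands epoch-level
# decay to the LambdaLR scheduler added below.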
@@ -702,12 +701,18 @@ def run(rank, world_size, args):
         model.parameters(),
         model_size=params.attention_dim,
         factor=params.lr_factor,
-        warm_step=params.warm_step,
     )
-    if checkpoints:
+    scheduler = torch.optim.lr_scheduler.LambdaLR(
+        optimizer, lambda ep: 1.0 if ep < 3 else 0.7 ** (ep - 2)
+    )
+
+    if checkpoints and checkpoints["optimizer"]:
         optimizer.load_state_dict(checkpoints["optimizer"])
+    if checkpoints and checkpoints["scheduler"]:
+        scheduler.load_state_dict(checkpoints["scheduler"])
 
     librispeech = LibriSpeechAsrDataModule(args)
     train_dl = librispeech.train_dataloaders()
     valid_dl = librispeech.valid_dataloaders()
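The lambda above keeps the base LR unchanged for the first three epochs and then multiplies it by 0.7 for each further epoch. A self-contained sketch of the resulting values follows; the base LR of 1e-3 and the SGD optimizer are stand-ins for illustration, not what Moam uses.

# Stand-alone illustration of the epoch-level decay produced by the lambda
# above.  base_lr and SGD are hypothetical stand-ins; in train.py the base LR
# comes from the Moam optimizer.
import torch

base_lr = 1e-3
opt = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=base_lr)
sched = torch.optim.lr_scheduler.LambdaLR(
    opt, lambda ep: 1.0 if ep < 3 else 0.7 ** (ep - 2)
)
for epoch in range(6):
    print(epoch, sched.get_last_lr()[0])
    opt.step()    # training for this epoch would happen here
    sched.step()  # decay once per epoch
# epochs 0-2: 1.0e-3; epoch 3: 7.0e-4; epoch 4: 4.9e-4; epoch 5: 3.43e-4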
@@ -715,7 +720,9 @@ def run(rank, world_size, args):
     for epoch in range(params.start_epoch, params.num_epochs):
         train_dl.sampler.set_epoch(epoch)
 
-        cur_lr = optimizer._rate
+        # LR scheduler can hold multiple learning rates for multiple parameter groups;
+        # For now we report just the first LR which we assume concerns most of the parameters.
+        cur_lr = scheduler.get_last_lr()[0]
         if tb_writer is not None:
             tb_writer.add_scalar(
                 "train/learning_rate", cur_lr, params.batch_idx_train
@@ -738,10 +745,13 @@ def run(rank, world_size, args):
             world_size=world_size,
         )
 
+        scheduler.step()
+
         save_checkpoint(
             params=params,
             model=model,
             optimizer=optimizer,
+            scheduler=scheduler,
             rank=rank,
         )
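Taken together, the epoch loop now steps the scheduler once per epoch after training and checkpoints the scheduler state alongside the optimizer. A toy end-to-end sketch of that ordering follows; the linear model, the trivial training step, and the checkpoint path are hypothetical stand-ins for the real train_one_epoch / save_checkpoint.

# Toy sketch of the new ordering: train an epoch, step the scheduler, then
# checkpoint optimizer *and* scheduler so a resumed run keeps the decayed LR.
# The model, loss and "toy-checkpoint.pt" path are stand-ins only.
import torch

model = torch.nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer, lambda ep: 1.0 if ep < 3 else 0.7 ** (ep - 2)
)

for epoch in range(5):
    # stand-in for train_one_epoch(); optimizer.step() runs per batch there
    optimizer.zero_grad()
    model(torch.randn(8, 4)).pow(2).mean().backward()
    optimizer.step()

    scheduler.step()  # decay once per epoch, after the epoch's training

    torch.save(
        {
            "model": model.state_dict(),
            "optimizer": optimizer.state_dict(),
            "scheduler": scheduler.state_dict(),  # needed to resume the decay
        },
        "toy-checkpoint.pt",
    )

# On resume, loading the scheduler state restores last_epoch, so
# get_last_lr() continues from the decayed value instead of the base LR.
ckpt = torch.load("toy-checkpoint.pt")
scheduler.load_state_dict(ckpt["scheduler"])
print(scheduler.get_last_lr())  # reflects 5 completed epochs of decay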