Replace warmup with lr scheduler.
parent 21292066ec
commit 0be42bef69
egs/librispeech/ASR/conformer_ctc_madam_no_warmup/madam_no_warmup.py | 1072 (new file)
File diff suppressed because it is too large
@@ -18,7 +18,7 @@ from conformer import Conformer
 from lhotse.utils import fix_random_seed
 
 # from transformer import Noam
-from madam import Moam
+from madam_no_warmup import Moam
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.nn.utils import clip_grad_norm_
 from torch.utils.tensorboard import SummaryWriter
@@ -155,7 +155,6 @@ def get_params() -> AttributeDict:
             "mmi_loss": False,
             "use_feat_batchnorm": False,
             "lr_factor": 2.0,
-            "warm_step": 30000,
         }
     )
 
@@ -702,12 +701,18 @@ def run(rank, world_size, args):
         model.parameters(),
         model_size=params.attention_dim,
         factor=params.lr_factor,
-        warm_step=params.warm_step,
     )
 
-    if checkpoints:
+    scheduler = torch.optim.lr_scheduler.LambdaLR(
+        optimizer, lambda ep: 1.0 if ep < 3 else 0.7 ** (ep - 2)
+    )
+
+    if checkpoints and checkpoints["optimizer"]:
         optimizer.load_state_dict(checkpoints["optimizer"])
 
+    if checkpoints and checkpoints["scheduler"]:
+        scheduler.load_state_dict(checkpoints["scheduler"])
+
     librispeech = LibriSpeechAsrDataModule(args)
     train_dl = librispeech.train_dataloaders()
     valid_dl = librispeech.valid_dataloaders()
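
The LambdaLR above replaces Moam's internal warmup: the multiplier stays at 1.0 for the first three epochs and then decays by a factor of 0.7 per epoch. Below is a minimal sketch of that schedule in isolation; plain SGD stands in for Moam, and the base LR of 1.0 and the six epochs are illustrative only, not values from the recipe.

import torch

# Stand-in parameter and optimizer; the recipe wraps its Moam optimizer instead.
param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.SGD([param], lr=1.0)

# Same lambda as in the diff: flat for epochs 0-2, then multiply by 0.7 each epoch.
scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer, lambda ep: 1.0 if ep < 3 else 0.7 ** (ep - 2)
)

for epoch in range(6):
    print(epoch, scheduler.get_last_lr()[0])  # 1.0, 1.0, 1.0, 0.7, 0.49, 0.343
    optimizer.step()   # one epoch of training would happen here
    scheduler.step()   # advance the epoch-based schedule, as the diff does below
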
@@ -715,7 +720,9 @@ def run(rank, world_size, args):
     for epoch in range(params.start_epoch, params.num_epochs):
         train_dl.sampler.set_epoch(epoch)
 
-        cur_lr = optimizer._rate
+        # LR scheduler can hold multiple learning rates for multiple parameter groups;
+        # for now we report just the first LR which we assume concerns most of the parameters.
+        cur_lr = scheduler.get_last_lr()[0]
         if tb_writer is not None:
             tb_writer.add_scalar(
                 "train/learning_rate", cur_lr, params.batch_idx_train
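
The two comment lines added above are the reason the new code indexes [0]: get_last_lr() returns one value per parameter group. A small sketch with two hypothetical groups (again plain SGD as a stand-in; the group split and LRs are made up for illustration):

import torch

p1, p2 = torch.nn.Parameter(torch.zeros(1)), torch.nn.Parameter(torch.zeros(1))
# Two parameter groups: the second overrides the default LR.
optimizer = torch.optim.SGD([{"params": [p1]}, {"params": [p2], "lr": 0.1}], lr=1.0)
scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer, lambda ep: 1.0 if ep < 3 else 0.7 ** (ep - 2)
)

print(scheduler.get_last_lr())     # [1.0, 0.1] -- one entry per group
print(scheduler.get_last_lr()[0])  # 1.0 -- the single value logged to TensorBoard
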
@@ -738,10 +745,13 @@ def run(rank, world_size, args):
             world_size=world_size,
         )
 
+        scheduler.step()
+
         save_checkpoint(
             params=params,
             model=model,
             optimizer=optimizer,
+            scheduler=scheduler,
             rank=rank,
         )
 
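
Because the decay depends on the scheduler's internal epoch counter, resuming only reproduces the schedule if the scheduler state is checkpointed alongside the optimizer, which is what the new scheduler=scheduler argument and the checkpoints["scheduler"] branch handle. A rough sketch of that round trip, with a plain dict and SGD standing in for icefall's checkpoint helpers and Moam:

import torch

def make_pair():
    # Fresh stand-in optimizer plus the same epoch-based schedule as in the diff.
    param = torch.nn.Parameter(torch.zeros(1))
    opt = torch.optim.SGD([param], lr=1.0)
    sched = torch.optim.lr_scheduler.LambdaLR(
        opt, lambda ep: 1.0 if ep < 3 else 0.7 ** (ep - 2)
    )
    return opt, sched

optimizer, scheduler = make_pair()
for epoch in range(4):          # pretend we trained for four epochs
    optimizer.step()
    scheduler.step()
checkpoint = {"optimizer": optimizer.state_dict(), "scheduler": scheduler.state_dict()}

# On resume, rebuild both objects, then restore their states so the LR keeps
# decaying from epoch 4 instead of restarting at the flat 1.0 phase.
optimizer2, scheduler2 = make_pair()
optimizer2.load_state_dict(checkpoint["optimizer"])
scheduler2.load_state_dict(checkpoint["scheduler"])
assert scheduler2.get_last_lr() == scheduler.get_last_lr()  # both report 0.49
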