Mirror of https://github.com/k2-fsa/icefall.git
Replace warmup with lr scheduler.
parent 0be42bef69
commit 02e409b6ce
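The change drops the Moam optimizer (Madam wrapped in a per-step warmup schedule) in favour of plain Madam with a fixed base learning rate plus an epoch-level torch.optim.lr_scheduler.LambdaLR decay. A minimal sketch of the two learning-rate curves, assuming Moam followed the standard Noam formula (the Moam implementation itself is not shown in this diff):

# Sketch only -- not code from this repository.  It contrasts the schedule
# being removed (assumed Noam-style warmup) with the one introduced below.

def old_noam_lr(step: int, model_size: int = 256, factor: float = 2.0,
                warm_step: int = 25000) -> float:
    # Assumed Noam formula: linear warmup for warm_step steps, then
    # inverse-square-root decay, applied every optimizer step.
    step = max(step, 1)
    return factor * model_size ** -0.5 * min(step ** -0.5, step * warm_step ** -1.5)

def new_epoch_lr(epoch: int, warm_step: int = 30000) -> float:
    # New behaviour per this commit: constant base LR from create_madam(),
    # scaled once per epoch by the LambdaLR factor set up in run().
    base = warm_step ** -0.5
    return base * (1.0 if epoch < 3 else 0.75 ** (epoch - 2))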
@@ -17,8 +17,7 @@ import torch.nn as nn
 from conformer import Conformer
 from lhotse.utils import fix_random_seed

-# from transformer import Noam
-from madam_no_warmup import Moam
+from madam import Madam
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.nn.utils import clip_grad_norm_
 from torch.utils.tensorboard import SummaryWriter
@@ -37,6 +36,29 @@ from icefall.utils import (
 )


+def create_madam(
+    params,
+    model_size: int = 256,
+    factor: float = 2.0,
+    warm_step: int = 25000,
+    min_target_rms: float = 0.05,
+    limit_grad_factor: float = float("inf"),
+    l2_period: int = 1,
+):
+    initial_lr = warm_step ** (-0.5)
+    optimizer = Madam(
+        params,
+        lr=initial_lr,
+        betas=(0.9, 0.98),
+        eps=1e-9,
+        min_target_rms=min_target_rms,
+        limit_grad_factor=limit_grad_factor,
+        l2_period=l2_period,
+    )
+
+    return optimizer
+
+
 def get_parser():
     parser = argparse.ArgumentParser(
         formatter_class=argparse.ArgumentDefaultsHelpFormatter
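In the lines shown, model_size and factor are accepted by create_madam but the base learning rate handed to Madam depends only on warm_step. As a quick sanity check (a sketch, not repository code), the base LR for the two warm_step values that appear in this diff:

# Sketch: base learning rate that create_madam() passes to Madam.
for warm_step in (25000, 30000):          # function default vs. get_params() value below
    print(warm_step, warm_step ** -0.5)   # 25000 -> ~6.32e-3, 30000 -> ~5.77e-3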
@@ -86,8 +108,6 @@ def get_params() -> AttributeDict:
         - lang_dir: It contains language related input files such as
           "lexicon.txt"

-        - lr: It specifies the initial learning rate
-
         - feature_dim: The model input dim. It has to match the one used
           in computing features.

@@ -155,6 +175,7 @@ def get_params() -> AttributeDict:
             "mmi_loss": False,
             "use_feat_batchnorm": False,
             "lr_factor": 2.0,
+            "warm_step": 30000,
         }
     )

@@ -697,14 +718,15 @@ def run(rank, world_size, args):
     if world_size > 1:
         model = DDP(model, device_ids=[rank])

-    optimizer = Moam(
+    optimizer = create_madam(
         model.parameters(),
         model_size=params.attention_dim,
         factor=params.lr_factor,
+        warm_step=params.warm_step,
     )

     scheduler = torch.optim.lr_scheduler.LambdaLR(
-        optimizer, lambda ep: 1.0 if ep < 3 else 0.7 ** (ep - 2)
+        optimizer, lambda ep: 1.0 if ep < 3 else 0.75 ** (ep - 2)
     )

     if checkpoints and checkpoints["optimizer"]:
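The per-epoch multiplier changes from 0.7 to 0.75, i.e. a slightly gentler decay: the base LR is held for epochs 0-2 and then shrinks geometrically. A small sketch of the resulting multipliers, assuming scheduler.step() is called once per epoch elsewhere in run() (not shown in this diff):

# Sketch of the LambdaLR multiplier introduced above (new 0.75 vs. old 0.7).
decay = lambda ep: 1.0 if ep < 3 else 0.75 ** (ep - 2)
print([round(decay(ep), 3) for ep in range(8)])
# -> [1.0, 1.0, 1.0, 0.75, 0.562, 0.422, 0.316, 0.237]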
@@ -720,8 +742,6 @@ def run(rank, world_size, args):
     for epoch in range(params.start_epoch, params.num_epochs):
         train_dl.sampler.set_epoch(epoch)

-        # LR scheduler can hold multiple learning rates for multiple parameter groups;
-        # For now we report just the first LR which we assume concerns most of the parameters.
         cur_lr = scheduler.get_last_lr()[0]
         if tb_writer is not None:
             tb_writer.add_scalar(
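The hunk is cut off inside the tb_writer.add_scalar(...) call. For reference, the standard SummaryWriter pattern for logging the value returned by scheduler.get_last_lr()[0] looks like the sketch below; the tag name and step argument are illustrative assumptions, not necessarily what the file uses:

# Illustrative sketch only; the tag and the step value are assumptions.
cur_lr = scheduler.get_last_lr()[0]
if tb_writer is not None:
    tb_writer.add_scalar("train/learning_rate", cur_lr, epoch)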