Set the initial learning rate directly.

Fangjun Kuang 2021-08-16 15:35:00 +08:00
parent 02e409b6ce
commit 58eb498219

@@ -38,17 +38,14 @@ from icefall.utils import (
 def create_madam(
     params,
-    model_size: int = 256,
-    factor: float = 2.0,
-    warm_step: int = 25000,
+    lr: float = 5e-4,
     min_target_rms: float = 0.05,
     limit_grad_factor: float = float("inf"),
     l2_period: int = 1,
 ):
-    initial_lr = warm_step ** (-0.5)
     optimizer = Madam(
         params,
-        lr=initial_lr,
+        lr=lr,
         betas=(0.9, 0.98),
         eps=1e-9,
         min_target_rms=min_target_rms,
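
Note: the removed lines derived the starting learning rate Noam-style from the warm-up length, whereas the new signature takes it directly. A rough sketch of the numerical effect, using the old and new defaults shown above (variable names are only illustrative):

    # Old: the initial LR was tied to warm_step.
    old_initial_lr = 25000 ** (-0.5)   # ~= 6.3e-3

    # New: the LR is an explicit argument, defaulting to 5e-4.
    new_lr = 5e-4

    print(f"old: {old_initial_lr:.2e}  new: {new_lr:.2e}")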
@@ -166,7 +163,6 @@ def get_params() -> AttributeDict:
             "reduction": "sum",
             "use_double_scores": True,
             #
-            "accum_grad": 1,
             "att_rate": 0.7,
             "attention_dim": 512,
             "nhead": 8,
@@ -174,8 +170,7 @@ def get_params() -> AttributeDict:
             "is_espnet_structure": True,
             "mmi_loss": False,
             "use_feat_batchnorm": False,
-            "lr_factor": 2.0,
-            "warm_step": 30000,
+            "lr": 5e-4,
         }
     )
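
The config change mirrors the optimizer change: "lr_factor" and "warm_step" are dropped from get_params() and a single "lr" entry replaces them, which the training code reads as params.lr. A minimal sketch, assuming icefall's AttributeDict simply exposes dict keys as attributes (as its use elsewhere in this file suggests):

    from icefall.utils import AttributeDict

    params = AttributeDict({"lr": 5e-4})
    assert params.lr == params["lr"] == 5e-4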
@@ -718,12 +713,7 @@ def run(rank, world_size, args):
     if world_size > 1:
         model = DDP(model, device_ids=[rank])
 
-    optimizer = create_madam(
-        model.parameters(),
-        model_size=params.attention_dim,
-        factor=params.lr_factor,
-        warm_step=params.warm_step,
-    )
+    optimizer = create_madam(model.parameters(), lr=params.lr)
 
     scheduler = torch.optim.lr_scheduler.LambdaLR(
         optimizer, lambda ep: 1.0 if ep < 3 else 0.75 ** (ep - 2)
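
With the initial LR now set directly, the effective per-epoch rate is just params.lr scaled by the LambdaLR factor above: constant for the first three epochs, then multiplied by 0.75 each following epoch. A small standalone sketch of the resulting schedule, using a plain SGD optimizer in place of Madam purely to drive the scheduler:

    import torch

    param = torch.nn.Parameter(torch.zeros(1))
    optimizer = torch.optim.SGD([param], lr=5e-4)
    scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer, lambda ep: 1.0 if ep < 3 else 0.75 ** (ep - 2)
    )

    for epoch in range(6):
        print(epoch, scheduler.get_last_lr()[0])
        scheduler.step()
    # epochs 0-2: 5.0e-4; epoch 3: 3.75e-4; epoch 4: ~2.81e-4; ...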