Mirror of https://github.com/k2-fsa/icefall.git, synced 2025-08-09 01:52:41 +00:00
Set the initial learning rate directly.

parent: 02e409b6ce
commit: 58eb498219
@@ -38,17 +38,14 @@ from icefall.utils import (
 
 def create_madam(
     params,
-    model_size: int = 256,
-    factor: float = 2.0,
-    warm_step: int = 25000,
+    lr: float = 5e-4,
     min_target_rms: float = 0.05,
     limit_grad_factor: float = float("inf"),
     l2_period: int = 1,
 ):
-    initial_lr = warm_step ** (-0.5)
     optimizer = Madam(
         params,
-        lr=initial_lr,
+        lr=lr,
         betas=(0.9, 0.98),
         eps=1e-9,
         min_target_rms=min_target_rms,
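Note on this hunk: previously the initial learning rate was derived Noam-style from the warmup step count (initial_lr = warm_step ** (-0.5)); after this commit it is an explicit argument with a fixed default. A quick sanity check of the magnitudes involved (plain Python, nothing icefall-specific):

    warm_step = 25000
    old_initial_lr = warm_step ** (-0.5)  # ~6.32e-3, the old derived value
    new_lr = 5e-4                         # the new explicit default

    print(f"{old_initial_lr:.2e} -> {new_lr:.2e}")  # 6.32e-03 -> 5.00e-04

So the new default starts roughly an order of magnitude lower than the value the old formula produced.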
@@ -166,7 +163,6 @@ def get_params() -> AttributeDict:
             "reduction": "sum",
             "use_double_scores": True,
             #
-            "accum_grad": 1,
             "att_rate": 0.7,
             "attention_dim": 512,
             "nhead": 8,
@@ -174,8 +170,7 @@ def get_params() -> AttributeDict:
             "is_espnet_structure": True,
             "mmi_loss": False,
             "use_feat_batchnorm": False,
-            "lr_factor": 2.0,
-            "warm_step": 30000,
+            "lr": 5e-4,
         }
     )
 
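The get_params() hunks mirror the signature change: the "lr_factor" and "warm_step" entries are replaced by a single "lr" entry. A minimal sketch of how the new key is consumed (the AttributeDict here is a hypothetical stand-in for icefall.utils.AttributeDict, assumed to expose dict keys as attributes):

    class AttributeDict(dict):
        """Hypothetical stand-in: a dict with attribute-style access."""
        def __getattr__(self, key):
            return self[key]

    params = AttributeDict({"lr": 5e-4})  # replaces "lr_factor" and "warm_step"
    assert params.lr == 5e-4              # read as params.lr in run() below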
@@ -718,12 +713,7 @@ def run(rank, world_size, args):
     if world_size > 1:
         model = DDP(model, device_ids=[rank])
 
-    optimizer = create_madam(
-        model.parameters(),
-        model_size=params.attention_dim,
-        factor=params.lr_factor,
-        warm_step=params.warm_step,
-    )
+    optimizer = create_madam(model.parameters(), lr=params.lr)
 
     scheduler = torch.optim.lr_scheduler.LambdaLR(
         optimizer, lambda ep: 1.0 if ep < 3 else 0.75 ** (ep - 2)
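With the warmup arguments gone from create_madam, the per-epoch LambdaLR schedule is the only learning-rate control left in this path: the multiplier is 1.0 for epochs 0-2 and then decays by a factor of 0.75 per epoch. A self-contained sketch of that schedule, using torch.optim.SGD on a dummy parameter as a stand-in for the icefall-internal Madam optimizer:

    import torch

    # Dummy parameter and a stand-in optimizer; only the schedule is of interest.
    param = torch.nn.Parameter(torch.zeros(1))
    optimizer = torch.optim.SGD([param], lr=5e-4)

    # Same lambda as in the diff above.
    scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer, lambda ep: 1.0 if ep < 3 else 0.75 ** (ep - 2)
    )

    for epoch in range(6):
        print(epoch, optimizer.param_groups[0]["lr"])
        optimizer.step()
        scheduler.step()
    # Epochs 0-2 stay at 5.0e-4; epoch 3 drops to 3.75e-4, epoch 4 to ~2.81e-4, ...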
|