mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
Expose scheduler parameters
This commit is contained in:
parent
8d269156a0
commit
274a1ef45f
@ -1,5 +1,6 @@
|
|||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
|
# Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang,
|
||||||
|
# Yifan Yang)
|
||||||
#
|
#
|
||||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||||
#
|
#
|
||||||
@ -251,6 +252,10 @@ def get_params() -> AttributeDict:
|
|||||||
"reset_interval": 2000,
|
"reset_interval": 2000,
|
||||||
"valid_interval": 5000,
|
"valid_interval": 5000,
|
||||||
"env_info": get_env_info(),
|
"env_info": get_env_info(),
|
||||||
|
# parameters for new_bob scheduler
|
||||||
|
"annealing_factor": 0.5,
|
||||||
|
"threshold": 0.0025,
|
||||||
|
"patient": 0,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
return params
|
return params
|
||||||
@ -662,8 +667,14 @@ def run(rank, world_size, args):
|
|||||||
model.parameters(),
|
model.parameters(),
|
||||||
lr=params.lr,
|
lr=params.lr,
|
||||||
weight_decay=params.weight_decay,
|
weight_decay=params.weight_decay,
|
||||||
|
)
|
||||||
|
|
||||||
|
scheduler = NewBobScheduler(
|
||||||
|
optimizer,
|
||||||
|
annealing_factor=params.annealing_factor,
|
||||||
|
threshold=params.threshold,
|
||||||
|
patient=params.patient,
|
||||||
)
|
)
|
||||||
scheduler = NewBobScheduler(optimizer)
|
|
||||||
|
|
||||||
if checkpoints:
|
if checkpoints:
|
||||||
logging.info("Load optimizer state_dict from checkpoint")
|
logging.info("Load optimizer state_dict from checkpoint")
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user