Expose scheduler parameters

This commit is contained in:
Yifan Yang 2023-06-02 16:59:29 +08:00
parent 8d269156a0
commit 274a1ef45f

View File

@ -1,5 +1,6 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
# Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang,
# Yifan Yang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
@ -251,6 +252,10 @@ def get_params() -> AttributeDict:
"reset_interval": 2000,
"valid_interval": 5000,
"env_info": get_env_info(),
# parameters for new_bob scheduler
"annealing_factor": 0.5,
"threshold": 0.0025,
"patient": 0,
}
)
return params
@ -663,7 +668,13 @@ def run(rank, world_size, args):
lr=params.lr,
weight_decay=params.weight_decay,
)
scheduler = NewBobScheduler(optimizer)
scheduler = NewBobScheduler(
optimizer,
annealing_factor=params.annealing_factor,
threshold=params.threshold,
patient=params.patient,
)
if checkpoints:
logging.info("Load optimizer state_dict from checkpoint")