mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
Expose scheduler parameters
This commit is contained in:
parent
8d269156a0
commit
274a1ef45f
@ -1,5 +1,6 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||
# Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang,
|
||||
# Yifan Yang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
@ -251,6 +252,10 @@ def get_params() -> AttributeDict:
|
||||
"reset_interval": 2000,
|
||||
"valid_interval": 5000,
|
||||
"env_info": get_env_info(),
|
||||
# parameters for new_bob scheduler
|
||||
"annealing_factor": 0.5,
|
||||
"threshold": 0.0025,
|
||||
"patient": 0,
|
||||
}
|
||||
)
|
||||
return params
|
||||
@ -663,7 +668,13 @@ def run(rank, world_size, args):
|
||||
lr=params.lr,
|
||||
weight_decay=params.weight_decay,
|
||||
)
|
||||
scheduler = NewBobScheduler(optimizer)
|
||||
|
||||
scheduler = NewBobScheduler(
|
||||
optimizer,
|
||||
annealing_factor=params.annealing_factor,
|
||||
threshold=params.threshold,
|
||||
patient=params.patient,
|
||||
)
|
||||
|
||||
if checkpoints:
|
||||
logging.info("Load optimizer state_dict from checkpoint")
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user