diff --git a/icefall/rnn_lm/train.py b/icefall/rnn_lm/train.py index af7b29096..5fc9a825a 100755 --- a/icefall/rnn_lm/train.py +++ b/icefall/rnn_lm/train.py @@ -1,5 +1,6 @@ #!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang, +# Yifan Yang) # # See ../../../../LICENSE for clarification regarding multiple authors # @@ -251,6 +252,10 @@ def get_params() -> AttributeDict: "reset_interval": 2000, "valid_interval": 5000, "env_info": get_env_info(), + # parameters for new_bob scheduler + "annealing_factor": 0.5, + "threshold": 0.0025, + "patient": 0, } ) return params @@ -662,8 +667,14 @@ def run(rank, world_size, args): model.parameters(), lr=params.lr, weight_decay=params.weight_decay, + ) + + scheduler = NewBobScheduler( + optimizer, + annealing_factor=params.annealing_factor, + threshold=params.threshold, + patient=params.patient, ) - scheduler = NewBobScheduler(optimizer) if checkpoints: logging.info("Load optimizer state_dict from checkpoint")