From 716a0a0da3d9c6048bdd72d00ae9661c64af02c1 Mon Sep 17 00:00:00 2001
From: Yifan Yang
Date: Sat, 3 Jun 2023 00:29:33 +0800
Subject: [PATCH] Fix

---
 icefall/rnn_lm/optim.py |  4 +++-
 icefall/rnn_lm/train.py | 10 +++++-----
 2 files changed, 8 insertions(+), 6 deletions(-)

diff --git a/icefall/rnn_lm/optim.py b/icefall/rnn_lm/optim.py
index 21328c7bb..f62bb4a80 100644
--- a/icefall/rnn_lm/optim.py
+++ b/icefall/rnn_lm/optim.py
@@ -165,7 +165,8 @@ class NewBobScheduler(LRScheduler):
             ) / self.prev_metric
             if improvement < self.threshold:
                 if self.current_patient == 0:
-                    self.base_lrs *= self.annealing_factor
+                    self.base_lrs = [x * self.annealing_factor for x in self.base_lrs]
+                    self.patient *= 2
                     self.current_patient = self.patient
                 else:
                     self.current_patient -= 1
@@ -180,4 +181,5 @@
             "prev_metric": self.prev_metric,
             "current_metric": self.current_metric,
             "current_patient": self.current_patient,
+            "patient": self.patient,
         }
diff --git a/icefall/rnn_lm/train.py b/icefall/rnn_lm/train.py
index 5568e79ce..8e07eddc6 100755
--- a/icefall/rnn_lm/train.py
+++ b/icefall/rnn_lm/train.py
@@ -135,21 +135,21 @@ def get_parser():
     parser.add_argument(
         "--lm-data",
         type=str,
-        default="data/lm_training_bpe_500/sorted_lm_data.pt",
+        default="data/lm_training_bpe_5000/sorted_lm_data.pt",
         help="LM training data",
     )
 
     parser.add_argument(
         "--lm-data-valid",
         type=str,
-        default="data/lm_training_bpe_500/sorted_lm_data-valid.pt",
+        default="data/lm_training_bpe_5000/sorted_lm_data-valid.pt",
         help="LM validation data",
     )
 
     parser.add_argument(
         "--vocab-size",
         type=int,
-        default=500,
+        default=5000,
         help="Vocabulary size of the model",
     )
 
@@ -255,7 +255,7 @@ def get_params() -> AttributeDict:
             # parameters for new_bob scheduler
             "annealing_factor": 0.5,
             "threshold": 0.0025,
-            "patient": 0,
+            "patient": 10,
         }
     )
     return params
@@ -509,7 +509,6 @@ def train_one_epoch(
         # summary stats
         tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
 
-        scheduler.step_batch(loss)
        optimizer.zero_grad()
         loss.backward()
         clip_grad_norm_(model.parameters(), 5.0, 2.0)
@@ -545,6 +544,7 @@
             # Note: "frames" here means "num_tokens"
             this_batch_ppl = math.exp(loss_info["loss"] / loss_info["frames"])
             tot_ppl = math.exp(tot_loss["loss"] / tot_loss["frames"])
+            scheduler.step_batch(tot_ppl)
             cur_lr = scheduler.get_last_lr()[0]
             logging.info(
                 f"Epoch {params.cur_epoch}, "
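
Notes on the optim.py change:

torch's LRScheduler keeps base_lrs as a plain Python list, so the old line
`self.base_lrs *= self.annealing_factor` raises
`TypeError: can't multiply sequence by non-int of type 'float'`; the list has
to be rebuilt element-wise. The hunk also doubles `patient` after each
learning-rate cut, so successive cuts are spaced further apart, and because
`patient` now mutates during training, the second hunk adds it to the
state dict so that resuming from a checkpoint keeps the back-off schedule.
A minimal standalone sketch of the resulting update rule (paraphrased from
the diff; `new_bob_step` and its flat argument list are illustrative, not the
repo's actual class API):

def new_bob_step(base_lrs, prev_metric, current_metric,
                 threshold, annealing_factor, patient, current_patient):
    # Relative improvement of the monitored metric (e.g. perplexity).
    improvement = (prev_metric - current_metric) / prev_metric
    if improvement < threshold:
        if current_patient == 0:
            # Rebuild the list; `base_lrs *= annealing_factor` would fail,
            # since a list cannot be scaled by a float in place.
            base_lrs = [x * annealing_factor for x in base_lrs]
            patient *= 2          # wait twice as long before the next cut
            current_patient = patient
        else:
            current_patient -= 1
    return base_lrs, patient, current_patient

A toy run on a metric that stalls (hypothetical numbers; a small initial
`patient` keeps the trace short):

base_lrs, patient, cur = [1e-3], 2, 0
prev = 100.0
for metric in [100.0, 99.9, 99.9, 99.9, 99.9, 99.9]:
    base_lrs, patient, cur = new_bob_step(
        base_lrs, prev, metric, threshold=0.0025,
        annealing_factor=0.5, patient=patient, current_patient=cur)
    prev = metric
print(base_lrs, patient)   # [0.00025] 8: two cuts, patience 2 -> 4 -> 8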
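
Notes on the train.py scheduler stepping:

Before this patch, `scheduler.step_batch(loss)` ran on every batch, on the
raw per-batch loss tensor, and before the backward pass; afterwards,
`step_batch(tot_ppl)` runs only every `log_interval` batches, on the
perplexity of the exponentially smoothed running loss, a steadier annealing
signal than a single noisy batch. A self-contained toy re-creation of the
new ordering (all numbers are stand-ins, and the backward/optimizer calls
are elided):

import math

reset_interval, log_interval = 200, 100
tot_loss, tot_frames = 0.0, 0.0
for batch_idx in range(1, 301):
    batch_loss, batch_frames = 520.0, 100.0   # fake per-batch stats
    decay = 1 - 1 / reset_interval            # smoothing from the diff context
    tot_loss = tot_loss * decay + batch_loss
    tot_frames = tot_frames * decay + batch_frames
    # ... loss.backward(), clip_grad_norm_() and optimizer.step() go here ...
    if batch_idx % log_interval == 0:
        tot_ppl = math.exp(tot_loss / tot_frames)
        # the patch calls scheduler.step_batch(tot_ppl) at this point
        print(f"batch {batch_idx}: tot_ppl = {tot_ppl:.2f}")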