Yifan Yang 2023-06-03 00:29:33 +08:00
parent c526e958e5
commit 716a0a0da3
2 changed files with 8 additions and 6 deletions

View File

@@ -165,7 +165,8 @@ class NewBobScheduler(LRScheduler):
             ) / self.prev_metric
             if improvement < self.threshold:
                 if self.current_patient == 0:
-                    self.base_lrs *= self.annealing_factor
+                    self.base_lrs = [x * self.annealing_factor for x in self.base_lrs]
+                    self.patient *= 2
                     self.current_patient = self.patient
                 else:
                     self.current_patient -= 1
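
Two behavioral fixes land in this hunk. `base_lrs` is a plain Python list in `torch.optim.lr_scheduler.LRScheduler`, so the old in-place `*=` with a float `annealing_factor` (0.5 per the params below) would raise a TypeError; the list comprehension scales each group's learning rate instead. The second added line doubles the patience after every annealing step. A minimal standalone sketch of the resulting update rule (the free function and its argument names are illustrative, not the class's actual method):

    def new_bob_update(base_lrs, prev_metric, current_metric,
                       threshold, annealing_factor, patient, current_patient):
        # Relative improvement of the validation metric, as in the context above.
        improvement = (prev_metric - current_metric) / prev_metric
        if improvement < threshold:
            if current_patient == 0:
                # Scale every group's LR element-wise (base_lrs is a list) and
                # double the patience before the next anneal.
                base_lrs = [x * annealing_factor for x in base_lrs]
                patient *= 2
                current_patient = patient
            else:
                current_patient -= 1
        return base_lrs, patient, current_patient

    print(new_bob_update([0.1], 10.0, 9.99, 0.0025, 0.5, 0, 0))
    # -> ([0.05], 0, 0): improvement 0.001 < 0.0025 and patience is exhausted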
@@ -180,4 +181,5 @@ class NewBobScheduler(LRScheduler):
             "prev_metric": self.prev_metric,
             "current_metric": self.current_metric,
             "current_patient": self.current_patient,
+            "patient": self.patient,
         }
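
Because `patient` now changes over training (it doubles at each anneal), it has to be serialized with the rest of the scheduler state; without this key, a resumed run would fall back to the constructor's patience. A hedged usage sketch, with `scheduler` standing in for a constructed NewBobScheduler:

    state = scheduler.state_dict()    # now carries "patient" alongside the counters
    scheduler.load_state_dict(state)  # a resumed run keeps the doubled patience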

View File

@@ -135,21 +135,21 @@ def get_parser():
     parser.add_argument(
         "--lm-data",
         type=str,
-        default="data/lm_training_bpe_500/sorted_lm_data.pt",
+        default="data/lm_training_bpe_5000/sorted_lm_data.pt",
         help="LM training data",
     )
 
     parser.add_argument(
         "--lm-data-valid",
         type=str,
-        default="data/lm_training_bpe_500/sorted_lm_data-valid.pt",
+        default="data/lm_training_bpe_5000/sorted_lm_data-valid.pt",
         help="LM validation data",
     )
 
     parser.add_argument(
         "--vocab-size",
         type=int,
-        default=500,
+        default=5000,
         help="Vocabulary size of the model",
     )
 
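
The three defaults move together from the 500-token BPE setup to the 5000-token one; `--vocab-size` must match the BPE model that produced the sorted LM data, or the embedding and output layers will be sized wrong for the token ids. A quick check of the new defaults (assuming `get_parser` from this file):

    args = get_parser().parse_args([])   # take every default
    assert args.vocab_size == 5000
    assert "bpe_5000" in args.lm_data and "bpe_5000" in args.lm_data_valid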
@@ -255,7 +255,7 @@ def get_params() -> AttributeDict:
             # parameters for new_bob scheduler
             "annealing_factor": 0.5,
             "threshold": 0.0025,
-            "patient": 0,
+            "patient": 10,
         }
     )
     return params
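
With the old default `patient: 0`, the learning rate was halved on every measurement whose relative improvement fell under `threshold`; with 10, plus the doubling introduced in the scheduler hunk above, the first anneal waits for 10 consecutive non-improving measurements, the next for 20, then 40. An illustrative worst-case trace (pretending no measurement ever improves):

    patient, current_patient, anneal_steps = 10, 10, []
    for step in range(100):
        if current_patient == 0:
            anneal_steps.append(step)   # the LR would be halved here
            patient *= 2
            current_patient = patient
        else:
            current_patient -= 1
    print(anneal_steps)                 # [10, 31, 72]: gaps of 10, 21, 41 steps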
@@ -509,7 +509,6 @@ def train_one_epoch(
         # summary stats
         tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
 
-        scheduler.step_batch(loss)
         optimizer.zero_grad()
         loss.backward()
         clip_grad_norm_(model.parameters(), 5.0, 2.0)
@@ -545,6 +544,7 @@ def train_one_epoch(
             # Note: "frames" here means "num_tokens"
             this_batch_ppl = math.exp(loss_info["loss"] / loss_info["frames"])
             tot_ppl = math.exp(tot_loss["loss"] / tot_loss["frames"])
+            scheduler.step_batch(tot_ppl)
             cur_lr = scheduler.get_last_lr()[0]
             logging.info(
                 f"Epoch {params.cur_epoch}, "