Yifan Yang 2023-06-03 00:29:33 +08:00
parent c526e958e5
commit 716a0a0da3
2 changed files with 8 additions and 6 deletions


@@ -165,7 +165,8 @@ class NewBobScheduler(LRScheduler):
             ) / self.prev_metric
         if improvement < self.threshold:
             if self.current_patient == 0:
-                self.base_lrs *= self.annealing_factor
+                self.base_lrs = [x * self.annealing_factor for x in self.base_lrs]
+                self.patient *= 2
                 self.current_patient = self.patient
             else:
                 self.current_patient -= 1
@@ -180,4 +181,5 @@ class NewBobScheduler(LRScheduler):
             "prev_metric": self.prev_metric,
             "current_metric": self.current_metric,
             "current_patient": self.current_patient,
+            "patient": self.patient,
         }
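
The scheduler fix above replaces an in-place `*=` on `base_lrs` with a list comprehension. Assuming `base_lrs` is a plain Python list of per-parameter-group learning rates (as in torch-style schedulers), the old line could never scale the rates: multiplying a list by a float raises `TypeError`, and multiplying by an int would repeat the list instead. A minimal runnable sketch with hypothetical values:

```python
# Minimal demonstration of the bug fixed above (hypothetical values).
# `base_lrs` is a plain Python list, so `*=` with a float is invalid,
# and `*=` with an int would repeat the list rather than scale it.
base_lrs = [0.01, 0.01]
annealing_factor = 0.5

try:
    base_lrs *= annealing_factor  # the old code path
except TypeError as e:
    print(e)  # can't multiply sequence by non-int of type 'float'

# The fixed code path: scale each learning rate element-wise.
base_lrs = [x * annealing_factor for x in base_lrs]
print(base_lrs)  # [0.005, 0.005]
```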


@@ -135,21 +135,21 @@ def get_parser():
     parser.add_argument(
         "--lm-data",
         type=str,
-        default="data/lm_training_bpe_500/sorted_lm_data.pt",
+        default="data/lm_training_bpe_5000/sorted_lm_data.pt",
         help="LM training data",
     )
     parser.add_argument(
         "--lm-data-valid",
         type=str,
-        default="data/lm_training_bpe_500/sorted_lm_data-valid.pt",
+        default="data/lm_training_bpe_5000/sorted_lm_data-valid.pt",
         help="LM validation data",
     )
     parser.add_argument(
         "--vocab-size",
         type=int,
-        default=500,
+        default=5000,
         help="Vocabulary size of the model",
     )
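
The three new defaults move together: the `lm_training_bpe_5000` directory name encodes the 5000-token BPE vocabulary, so a run that keeps the old 500-token setup has to override all three flags consistently. A hypothetical helper (not part of the script) that derives the data paths from a single vocabulary size:

```python
# Hypothetical helper, not part of the script: derive consistent LM data
# paths from the vocabulary size instead of overriding three flags by hand.
def lm_data_paths(vocab_size: int) -> dict:
    bpe_dir = f"data/lm_training_bpe_{vocab_size}"
    return {
        "lm_data": f"{bpe_dir}/sorted_lm_data.pt",
        "lm_data_valid": f"{bpe_dir}/sorted_lm_data-valid.pt",
    }

print(lm_data_paths(5000)["lm_data"])
# data/lm_training_bpe_5000/sorted_lm_data.pt
```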
@@ -255,7 +255,7 @@ def get_params() -> AttributeDict:
             # parameters for new_bob scheduler
             "annealing_factor": 0.5,
             "threshold": 0.0025,
-            "patient": 0,
+            "patient": 10,
         }
     )
     return params
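
Together with the `self.patient *= 2` line added in the scheduler, raising the default from `patient: 0` (anneal on every sub-threshold improvement) to `patient: 10` means the learning rate is only halved after eleven consecutive sub-threshold observations, and the window doubles after each anneal. A simplified, runnable re-implementation of just that branch (illustrative only, not the actual class):

```python
# Simplified sketch of the new-bob branch shown above (illustrative only).
def new_bob_step(state, improvement, threshold=0.0025, annealing_factor=0.5):
    """Apply one metric observation to the annealing/patience state."""
    if improvement < threshold:
        if state["current_patient"] == 0:
            # Anneal each learning rate element-wise (the fix above) ...
            state["base_lrs"] = [x * annealing_factor for x in state["base_lrs"]]
            # ... and double the patience window before the next anneal.
            state["patient"] *= 2
            state["current_patient"] = state["patient"]
        else:
            state["current_patient"] -= 1

state = {"base_lrs": [0.01], "patient": 10, "current_patient": 10}
for _ in range(11):  # ten decrements, then the eleventh triggers the anneal
    new_bob_step(state, improvement=0.0)
print(state["base_lrs"], state["patient"])  # [0.005] 20
```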
@@ -509,7 +509,6 @@ def train_one_epoch(
         # summary stats
         tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
-        scheduler.step_batch(loss)
         optimizer.zero_grad()
         loss.backward()
         clip_grad_norm_(model.parameters(), 5.0, 2.0)
@@ -545,6 +544,7 @@ def train_one_epoch(
             # Note: "frames" here means "num_tokens"
             this_batch_ppl = math.exp(loss_info["loss"] / loss_info["frames"])
             tot_ppl = math.exp(tot_loss["loss"] / tot_loss["frames"])
+            scheduler.step_batch(tot_ppl)
             cur_lr = scheduler.get_last_lr()[0]
             logging.info(
                 f"Epoch {params.cur_epoch}, "