mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
Fix
This commit is contained in:
parent
c526e958e5
commit
716a0a0da3
@ -165,7 +165,8 @@ class NewBobScheduler(LRScheduler):
|
||||
) / self.prev_metric
|
||||
if improvement < self.threshold:
|
||||
if self.current_patient == 0:
|
||||
self.base_lrs *= self.annealing_factor
|
||||
self.base_lrs = [x * self.annealing_factor for x in self.base_lrs]
|
||||
self.patient *= 2
|
||||
self.current_patient = self.patient
|
||||
else:
|
||||
self.current_patient -= 1
|
||||
@ -180,4 +181,5 @@ class NewBobScheduler(LRScheduler):
|
||||
"prev_metric": self.prev_metric,
|
||||
"current_metric": self.current_metric,
|
||||
"current_patient": self.current_patient,
|
||||
"patient": self.patient,
|
||||
}
|
||||
|
||||
@ -135,21 +135,21 @@ def get_parser():
|
||||
parser.add_argument(
|
||||
"--lm-data",
|
||||
type=str,
|
||||
default="data/lm_training_bpe_500/sorted_lm_data.pt",
|
||||
default="data/lm_training_bpe_5000/sorted_lm_data.pt",
|
||||
help="LM training data",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--lm-data-valid",
|
||||
type=str,
|
||||
default="data/lm_training_bpe_500/sorted_lm_data-valid.pt",
|
||||
default="data/lm_training_bpe_5000/sorted_lm_data-valid.pt",
|
||||
help="LM validation data",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--vocab-size",
|
||||
type=int,
|
||||
default=500,
|
||||
default=5000,
|
||||
help="Vocabulary size of the model",
|
||||
)
|
||||
|
||||
@ -255,7 +255,7 @@ def get_params() -> AttributeDict:
|
||||
# parameters for new_bob scheduler
|
||||
"annealing_factor": 0.5,
|
||||
"threshold": 0.0025,
|
||||
"patient": 0,
|
||||
"patient": 10,
|
||||
}
|
||||
)
|
||||
return params
|
||||
@ -509,7 +509,6 @@ def train_one_epoch(
|
||||
# summary stats
|
||||
tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
|
||||
|
||||
scheduler.step_batch(loss)
|
||||
optimizer.zero_grad()
|
||||
loss.backward()
|
||||
clip_grad_norm_(model.parameters(), 5.0, 2.0)
|
||||
@ -545,6 +544,7 @@ def train_one_epoch(
|
||||
# Note: "frames" here means "num_tokens"
|
||||
this_batch_ppl = math.exp(loss_info["loss"] / loss_info["frames"])
|
||||
tot_ppl = math.exp(tot_loss["loss"] / tot_loss["frames"])
|
||||
scheduler.step_batch(tot_ppl)
|
||||
cur_lr = scheduler.get_last_lr()[0]
|
||||
logging.info(
|
||||
f"Epoch {params.cur_epoch}, "
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user