mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
Fix
This commit is contained in:
parent
c526e958e5
commit
716a0a0da3
@ -165,7 +165,8 @@ class NewBobScheduler(LRScheduler):
|
|||||||
) / self.prev_metric
|
) / self.prev_metric
|
||||||
if improvement < self.threshold:
|
if improvement < self.threshold:
|
||||||
if self.current_patient == 0:
|
if self.current_patient == 0:
|
||||||
self.base_lrs *= self.annealing_factor
|
self.base_lrs = [x * self.annealing_factor for x in self.base_lrs]
|
||||||
|
self.patient *= 2
|
||||||
self.current_patient = self.patient
|
self.current_patient = self.patient
|
||||||
else:
|
else:
|
||||||
self.current_patient -= 1
|
self.current_patient -= 1
|
||||||
@ -180,4 +181,5 @@ class NewBobScheduler(LRScheduler):
|
|||||||
"prev_metric": self.prev_metric,
|
"prev_metric": self.prev_metric,
|
||||||
"current_metric": self.current_metric,
|
"current_metric": self.current_metric,
|
||||||
"current_patient": self.current_patient,
|
"current_patient": self.current_patient,
|
||||||
|
"patient": self.patient,
|
||||||
}
|
}
|
||||||
|
|||||||
@ -135,21 +135,21 @@ def get_parser():
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--lm-data",
|
"--lm-data",
|
||||||
type=str,
|
type=str,
|
||||||
default="data/lm_training_bpe_500/sorted_lm_data.pt",
|
default="data/lm_training_bpe_5000/sorted_lm_data.pt",
|
||||||
help="LM training data",
|
help="LM training data",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--lm-data-valid",
|
"--lm-data-valid",
|
||||||
type=str,
|
type=str,
|
||||||
default="data/lm_training_bpe_500/sorted_lm_data-valid.pt",
|
default="data/lm_training_bpe_5000/sorted_lm_data-valid.pt",
|
||||||
help="LM validation data",
|
help="LM validation data",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--vocab-size",
|
"--vocab-size",
|
||||||
type=int,
|
type=int,
|
||||||
default=500,
|
default=5000,
|
||||||
help="Vocabulary size of the model",
|
help="Vocabulary size of the model",
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -255,7 +255,7 @@ def get_params() -> AttributeDict:
|
|||||||
# parameters for new_bob scheduler
|
# parameters for new_bob scheduler
|
||||||
"annealing_factor": 0.5,
|
"annealing_factor": 0.5,
|
||||||
"threshold": 0.0025,
|
"threshold": 0.0025,
|
||||||
"patient": 0,
|
"patient": 10,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
return params
|
return params
|
||||||
@ -509,7 +509,6 @@ def train_one_epoch(
|
|||||||
# summary stats
|
# summary stats
|
||||||
tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
|
tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
|
||||||
|
|
||||||
scheduler.step_batch(loss)
|
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
loss.backward()
|
loss.backward()
|
||||||
clip_grad_norm_(model.parameters(), 5.0, 2.0)
|
clip_grad_norm_(model.parameters(), 5.0, 2.0)
|
||||||
@ -545,6 +544,7 @@ def train_one_epoch(
|
|||||||
# Note: "frames" here means "num_tokens"
|
# Note: "frames" here means "num_tokens"
|
||||||
this_batch_ppl = math.exp(loss_info["loss"] / loss_info["frames"])
|
this_batch_ppl = math.exp(loss_info["loss"] / loss_info["frames"])
|
||||||
tot_ppl = math.exp(tot_loss["loss"] / tot_loss["frames"])
|
tot_ppl = math.exp(tot_loss["loss"] / tot_loss["frames"])
|
||||||
|
scheduler.step_batch(tot_ppl)
|
||||||
cur_lr = scheduler.get_last_lr()[0]
|
cur_lr = scheduler.get_last_lr()[0]
|
||||||
logging.info(
|
logging.info(
|
||||||
f"Epoch {params.cur_epoch}, "
|
f"Epoch {params.cur_epoch}, "
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user