Load num_tokens_seen from disk on checkpoint load.

This commit is contained in:
Daniel Povey 2023-06-20 02:54:47 +08:00
parent b3b3e5daa0
commit 74bf02bba6

View File

@@ -474,6 +474,7 @@ def load_checkpoint_if_available(
keys = [
"batch_idx_train",
"num_tokens_seen",
]
for k in keys:
params[k] = saved_params[k]
@@ -647,7 +648,6 @@ def train(
tb_writer: Optional[SummaryWriter] = None,
world_size: int = 1,
rank: int = 0,
batch_idx_offset: int = 0,
) -> None:
"""Train the model until we have trained on the specified --num-tokens.
@@ -697,11 +697,10 @@ def train(
for batch_idx_, batch in enumerate(train_dl):
params.batch_idx_train += 1
if params.batch_idx_train % 10 == 0:
set_batch_count(model, get_adjusted_batch_count(params))
params.batch_idx_train += 1
try:
with torch.cuda.amp.autocast(enabled=params.use_fp16):