From 74bf02bba6016c1eb37858a4e0e8a40f7d302bdb Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Tue, 20 Jun 2023 02:54:47 +0800
Subject: [PATCH] Load num_tokens_seen from disk on checkpoint load.

---
 egs/libriheavy/LM/zipformer1/train.py | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/egs/libriheavy/LM/zipformer1/train.py b/egs/libriheavy/LM/zipformer1/train.py
index 5f38c6ce4..63519c1a3 100755
--- a/egs/libriheavy/LM/zipformer1/train.py
+++ b/egs/libriheavy/LM/zipformer1/train.py
@@ -474,6 +474,7 @@ def load_checkpoint_if_available(
     keys = [
         "batch_idx_train",
+        "num_tokens_seen",
     ]
     for k in keys:
         params[k] = saved_params[k]
@@ -647,7 +648,6 @@ def train(
     tb_writer: Optional[SummaryWriter] = None,
     world_size: int = 1,
     rank: int = 0,
-    batch_idx_offset: int = 0,
 ) -> None:
     """Train the model until we have trained on the specified --num-tokens.
@@ -697,11 +697,10 @@ def train(
     for batch_idx_, batch in enumerate(train_dl):
-        params.batch_idx_train += 1
-
         if params.batch_idx_train % 10 == 0:
             set_batch_count(model, get_adjusted_batch_count(params))
+        params.batch_idx_train += 1

         try:
             with torch.cuda.amp.autocast(enabled=params.use_fp16):
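
Note: a minimal sketch of what the first hunk amounts to at load time, for readers without the full train.py in front of them. The checkpoint is a dict, and the listed counter keys are copied back into params so they survive a restart. The helper name load_counters_if_available and the direct torch.load call here are illustrative assumptions; only the keys list and the copy loop come from the patch itself.

    import torch

    def load_counters_if_available(params: dict, checkpoint_path: str) -> None:
        # Sketch only: the real load_checkpoint_if_available() in train.py also
        # restores model/optimizer state; this isolates the counter restoration
        # shown in the first hunk.
        saved_params = torch.load(checkpoint_path, map_location="cpu")
        keys = [
            "batch_idx_train",
            "num_tokens_seen",  # newly restored by this patch
        ]
        for k in keys:
            params[k] = saved_params[k]

With num_tokens_seen restored from disk, a resumed run can honor --num-tokens instead of restarting its token count from zero; the third hunk is a related ordering tweak that increments params.batch_idx_train after the set_batch_count check rather than before it.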