mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
Load num_tokens_seen from disk on checkpoint load.
This commit is contained in:
parent
b3b3e5daa0
commit
74bf02bba6
@ -474,6 +474,7 @@ def load_checkpoint_if_available(
|
||||
|
||||
keys = [
|
||||
"batch_idx_train",
|
||||
"num_tokens_seen",
|
||||
]
|
||||
for k in keys:
|
||||
params[k] = saved_params[k]
|
||||
@ -647,7 +648,6 @@ def train(
|
||||
tb_writer: Optional[SummaryWriter] = None,
|
||||
world_size: int = 1,
|
||||
rank: int = 0,
|
||||
batch_idx_offset: int = 0,
|
||||
) -> None:
|
||||
"""Train the model until we have trained on the specified --num-tokens.
|
||||
|
||||
@ -697,11 +697,10 @@ def train(
|
||||
|
||||
|
||||
for batch_idx_, batch in enumerate(train_dl):
|
||||
params.batch_idx_train += 1
|
||||
|
||||
if params.batch_idx_train % 10 == 0:
|
||||
set_batch_count(model, get_adjusted_batch_count(params))
|
||||
|
||||
params.batch_idx_train += 1
|
||||
|
||||
try:
|
||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user