mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
Load num_tokens_seen from disk on checkpoint load.
This commit is contained in:
parent
b3b3e5daa0
commit
74bf02bba6
@ -474,6 +474,7 @@ def load_checkpoint_if_available(
|
|||||||
|
|
||||||
keys = [
|
keys = [
|
||||||
"batch_idx_train",
|
"batch_idx_train",
|
||||||
|
"num_tokens_seen",
|
||||||
]
|
]
|
||||||
for k in keys:
|
for k in keys:
|
||||||
params[k] = saved_params[k]
|
params[k] = saved_params[k]
|
||||||
@ -647,7 +648,6 @@ def train(
|
|||||||
tb_writer: Optional[SummaryWriter] = None,
|
tb_writer: Optional[SummaryWriter] = None,
|
||||||
world_size: int = 1,
|
world_size: int = 1,
|
||||||
rank: int = 0,
|
rank: int = 0,
|
||||||
batch_idx_offset: int = 0,
|
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Train the model until we have trained on the specified --num-tokens.
|
"""Train the model until we have trained on the specified --num-tokens.
|
||||||
|
|
||||||
@ -697,11 +697,10 @@ def train(
|
|||||||
|
|
||||||
|
|
||||||
for batch_idx_, batch in enumerate(train_dl):
|
for batch_idx_, batch in enumerate(train_dl):
|
||||||
params.batch_idx_train += 1
|
|
||||||
|
|
||||||
if params.batch_idx_train % 10 == 0:
|
if params.batch_idx_train % 10 == 0:
|
||||||
set_batch_count(model, get_adjusted_batch_count(params))
|
set_batch_count(model, get_adjusted_batch_count(params))
|
||||||
|
|
||||||
|
params.batch_idx_train += 1
|
||||||
|
|
||||||
try:
|
try:
|
||||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user