mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 01:52:41 +00:00
restore checkpoint save after validation
This commit is contained in:
parent
06667e1f6d
commit
62dfe56cbe
@ -685,6 +685,30 @@ def train_one_epoch(
|
|||||||
valid_info.write_summary(
|
valid_info.write_summary(
|
||||||
tb_writer, "train/valid_", params.batch_idx_train
|
tb_writer, "train/valid_", params.batch_idx_train
|
||||||
)
|
)
|
||||||
|
if batch_idx != 0:
|
||||||
|
model.save_checkpoint(
|
||||||
|
save_dir=params.exp_dir,
|
||||||
|
tag=f"epoch-{params.cur_epoch}-checkpoint-{batch_idx}",
|
||||||
|
client_state={},
|
||||||
|
exclude_frozen_parameters=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
if rank == 0:
|
||||||
|
convert_zero_checkpoint_to_fp32_state_dict(
|
||||||
|
params.exp_dir,
|
||||||
|
f"{params.exp_dir}/epoch-{params.cur_epoch}-checkpoint-{batch_idx}.pt",
|
||||||
|
tag=f"epoch-{params.cur_epoch}-checkpoint-{batch_idx}",
|
||||||
|
exclude_frozen_parameters=True,
|
||||||
|
)
|
||||||
|
# save sampler state dict into checkpoint
|
||||||
|
sampler_state_dict = train_dl.sampler.state_dict()
|
||||||
|
torch.save(
|
||||||
|
sampler_state_dict,
|
||||||
|
f"{params.exp_dir}/epoch-{params.cur_epoch}-checkpoint-{batch_idx}-sampler.pt",
|
||||||
|
)
|
||||||
|
os.system(
|
||||||
|
f"rm -rf {params.exp_dir}/epoch-{params.cur_epoch}-checkpoint-{batch_idx}"
|
||||||
|
)
|
||||||
|
|
||||||
shave_rate = params.shave_rate
|
shave_rate = params.shave_rate
|
||||||
while True:
|
while True:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user