remove checkpoint save after validation

This commit is contained in:
yfyeung 2025-05-12 06:36:20 +00:00
parent c078772e59
commit 2793ccdf56

View File

@ -500,10 +500,10 @@ def compute_loss(
device = next(model.parameters()).device
feature = batch["inputs"]
assert feature.ndim == 3
features = batch["inputs"]
assert features.ndim == 3
if params.use_fp16:
feature = feature.half()
features = features.half()
supervisions = batch["supervisions"]
feature_lens = supervisions["num_frames"]
@ -526,7 +526,7 @@ def compute_loss(
with torch.set_grad_enabled(is_training):
model_outputs, acc = model(
fbank=feature.to(device),
fbank=features.to(device),
fbank_lens=feature_lens.to(device),
input_ids=input_ids.to(device),
attention_mask=attention_mask.to(device),
@ -663,30 +663,6 @@ def train_one_epoch(
valid_info.write_summary(
tb_writer, "train/valid_", params.batch_idx_train
)
if batch_idx != 0:
model.save_checkpoint(
save_dir=params.exp_dir,
tag=f"epoch-{params.cur_epoch}-checkpoint-{batch_idx}",
client_state={},
exclude_frozen_parameters=True,
)
if rank == 0:
convert_zero_checkpoint_to_fp32_state_dict(
params.exp_dir,
f"{params.exp_dir}/epoch-{params.cur_epoch}-checkpoint-{batch_idx}.pt",
tag=f"epoch-{params.cur_epoch}-checkpoint-{batch_idx}",
exclude_frozen_parameters=True,
)
# save sampler state dict into checkpoint
sampler_state_dict = train_dl.sampler.state_dict()
torch.save(
sampler_state_dict,
f"{params.exp_dir}/epoch-{params.cur_epoch}-checkpoint-{batch_idx}-sampler.pt",
)
os.system(
f"rm -rf {params.exp_dir}/epoch-{params.cur_epoch}-checkpoint-{batch_idx}"
)
try:
loss, loss_info = compute_loss(
params=params,