mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 18:12:19 +00:00
remove checkpoint save after validation
This commit is contained in:
parent
c078772e59
commit
2793ccdf56
@ -500,10 +500,10 @@ def compute_loss(
|
||||
|
||||
device = next(model.parameters()).device
|
||||
|
||||
feature = batch["inputs"]
|
||||
assert feature.ndim == 3
|
||||
features = batch["inputs"]
|
||||
assert features.ndim == 3
|
||||
if params.use_fp16:
|
||||
feature = feature.half()
|
||||
features = features.half()
|
||||
|
||||
supervisions = batch["supervisions"]
|
||||
feature_lens = supervisions["num_frames"]
|
||||
@ -526,7 +526,7 @@ def compute_loss(
|
||||
|
||||
with torch.set_grad_enabled(is_training):
|
||||
model_outputs, acc = model(
|
||||
fbank=feature.to(device),
|
||||
fbank=features.to(device),
|
||||
fbank_lens=feature_lens.to(device),
|
||||
input_ids=input_ids.to(device),
|
||||
attention_mask=attention_mask.to(device),
|
||||
@ -663,30 +663,6 @@ def train_one_epoch(
|
||||
valid_info.write_summary(
|
||||
tb_writer, "train/valid_", params.batch_idx_train
|
||||
)
|
||||
if batch_idx != 0:
|
||||
model.save_checkpoint(
|
||||
save_dir=params.exp_dir,
|
||||
tag=f"epoch-{params.cur_epoch}-checkpoint-{batch_idx}",
|
||||
client_state={},
|
||||
exclude_frozen_parameters=True,
|
||||
)
|
||||
|
||||
if rank == 0:
|
||||
convert_zero_checkpoint_to_fp32_state_dict(
|
||||
params.exp_dir,
|
||||
f"{params.exp_dir}/epoch-{params.cur_epoch}-checkpoint-{batch_idx}.pt",
|
||||
tag=f"epoch-{params.cur_epoch}-checkpoint-{batch_idx}",
|
||||
exclude_frozen_parameters=True,
|
||||
)
|
||||
# save sampler state dict into checkpoint
|
||||
sampler_state_dict = train_dl.sampler.state_dict()
|
||||
torch.save(
|
||||
sampler_state_dict,
|
||||
f"{params.exp_dir}/epoch-{params.cur_epoch}-checkpoint-{batch_idx}-sampler.pt",
|
||||
)
|
||||
os.system(
|
||||
f"rm -rf {params.exp_dir}/epoch-{params.cur_epoch}-checkpoint-{batch_idx}"
|
||||
)
|
||||
try:
|
||||
loss, loss_info = compute_loss(
|
||||
params=params,
|
||||
|
Loading…
x
Reference in New Issue
Block a user