mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-10 18:42:19 +00:00
remove checkpoint save after validation
This commit is contained in:
parent
c078772e59
commit
2793ccdf56
@ -500,10 +500,10 @@ def compute_loss(
|
|||||||
|
|
||||||
device = next(model.parameters()).device
|
device = next(model.parameters()).device
|
||||||
|
|
||||||
feature = batch["inputs"]
|
features = batch["inputs"]
|
||||||
assert feature.ndim == 3
|
assert features.ndim == 3
|
||||||
if params.use_fp16:
|
if params.use_fp16:
|
||||||
feature = feature.half()
|
features = features.half()
|
||||||
|
|
||||||
supervisions = batch["supervisions"]
|
supervisions = batch["supervisions"]
|
||||||
feature_lens = supervisions["num_frames"]
|
feature_lens = supervisions["num_frames"]
|
||||||
@ -526,7 +526,7 @@ def compute_loss(
|
|||||||
|
|
||||||
with torch.set_grad_enabled(is_training):
|
with torch.set_grad_enabled(is_training):
|
||||||
model_outputs, acc = model(
|
model_outputs, acc = model(
|
||||||
fbank=feature.to(device),
|
fbank=features.to(device),
|
||||||
fbank_lens=feature_lens.to(device),
|
fbank_lens=feature_lens.to(device),
|
||||||
input_ids=input_ids.to(device),
|
input_ids=input_ids.to(device),
|
||||||
attention_mask=attention_mask.to(device),
|
attention_mask=attention_mask.to(device),
|
||||||
@ -663,30 +663,6 @@ def train_one_epoch(
|
|||||||
valid_info.write_summary(
|
valid_info.write_summary(
|
||||||
tb_writer, "train/valid_", params.batch_idx_train
|
tb_writer, "train/valid_", params.batch_idx_train
|
||||||
)
|
)
|
||||||
if batch_idx != 0:
|
|
||||||
model.save_checkpoint(
|
|
||||||
save_dir=params.exp_dir,
|
|
||||||
tag=f"epoch-{params.cur_epoch}-checkpoint-{batch_idx}",
|
|
||||||
client_state={},
|
|
||||||
exclude_frozen_parameters=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
if rank == 0:
|
|
||||||
convert_zero_checkpoint_to_fp32_state_dict(
|
|
||||||
params.exp_dir,
|
|
||||||
f"{params.exp_dir}/epoch-{params.cur_epoch}-checkpoint-{batch_idx}.pt",
|
|
||||||
tag=f"epoch-{params.cur_epoch}-checkpoint-{batch_idx}",
|
|
||||||
exclude_frozen_parameters=True,
|
|
||||||
)
|
|
||||||
# save sampler state dict into checkpoint
|
|
||||||
sampler_state_dict = train_dl.sampler.state_dict()
|
|
||||||
torch.save(
|
|
||||||
sampler_state_dict,
|
|
||||||
f"{params.exp_dir}/epoch-{params.cur_epoch}-checkpoint-{batch_idx}-sampler.pt",
|
|
||||||
)
|
|
||||||
os.system(
|
|
||||||
f"rm -rf {params.exp_dir}/epoch-{params.cur_epoch}-checkpoint-{batch_idx}"
|
|
||||||
)
|
|
||||||
try:
|
try:
|
||||||
loss, loss_info = compute_loss(
|
loss, loss_info = compute_loss(
|
||||||
params=params,
|
params=params,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user