pruned_transducer_stateless7: use symlinks (when possible) to output best epochs

This commit is contained in:
Peter Ross 2023-06-08 17:03:30 +10:00
parent 6c6ae63821
commit 05906e065b

View File

@ -50,7 +50,6 @@ import copy
import logging
import warnings
from pathlib import Path
from shutil import copyfile
from typing import Any, Dict, Optional, Tuple, Union
import k2
@ -89,6 +88,7 @@ from icefall.utils import (
filter_uneven_sized_batch,
setup_logger,
str2bool,
symlink_or_copyfile,
)
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@ -601,7 +601,8 @@ def save_checkpoint(
"""
if rank != 0:
return
filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt"
epoch_basename = f"epoch-{params.cur_epoch}.pt"
filename = params.exp_dir / epoch_basename
save_checkpoint_impl(
filename=filename,
model=model,
@ -615,12 +616,14 @@ def save_checkpoint(
)
if params.best_train_epoch == params.cur_epoch:
best_train_filename = params.exp_dir / "best-train-loss.pt"
copyfile(src=filename, dst=best_train_filename)
symlink_or_copyfile(
exp_dir=params.exp_dir, src=epoch_basename, dst="best-train-loss.pt"
)
if params.best_valid_epoch == params.cur_epoch:
best_valid_filename = params.exp_dir / "best-valid-loss.pt"
copyfile(src=filename, dst=best_valid_filename)
symlink_or_copyfile(
exp_dir=params.exp_dir, src=epoch_basename, dst="best-valid-loss.pt"
)
def compute_loss(