Use symlinks for best epochs (#1123)

* utils: add symlink_or_copyfile

* pruned_transducer_stateless7: use symlinks (when possible) to output best epochs

* Rename function

---------

Co-authored-by: Yifan Yang <64255737+yfyeung@users.noreply.github.com>
This commit is contained in:
Peter Ross 2023-06-12 15:51:46 +10:00 committed by GitHub
parent dca21c2a17
commit b4c38d7547
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 27 additions and 6 deletions

View File

@ -50,7 +50,6 @@ import copy
import logging
import warnings
from pathlib import Path
from shutil import copyfile
from typing import Any, Dict, Optional, Tuple, Union
import k2
@ -89,6 +88,7 @@ from icefall.utils import (
filter_uneven_sized_batch,
setup_logger,
str2bool,
symlink_or_copy,
)
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@ -601,7 +601,8 @@ def save_checkpoint(
"""
if rank != 0:
return
filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt"
epoch_basename = f"epoch-{params.cur_epoch}.pt"
filename = params.exp_dir / epoch_basename
save_checkpoint_impl(
filename=filename,
model=model,
@ -615,12 +616,14 @@ def save_checkpoint(
)
if params.best_train_epoch == params.cur_epoch:
best_train_filename = params.exp_dir / "best-train-loss.pt"
copyfile(src=filename, dst=best_train_filename)
symlink_or_copy(
exp_dir=params.exp_dir, src=epoch_basename, dst="best-train-loss.pt"
)
if params.best_valid_epoch == params.cur_epoch:
best_valid_filename = params.exp_dir / "best-valid-loss.pt"
copyfile(src=filename, dst=best_valid_filename)
symlink_or_copy(
exp_dir=params.exp_dir, src=epoch_basename, dst="best-valid-loss.pt"
)
def compute_loss(

View File

@ -28,6 +28,7 @@ from contextlib import contextmanager
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
from shutil import copyfile
from typing import Dict, Iterable, List, Optional, TextIO, Tuple, Union
import k2
@ -1881,3 +1882,20 @@ def is_cjk(character):
]
]
)
def symlink_or_copy(exp_dir: Path, src: str, dst: str):
    """
    In the experiment directory, create a symlink named dst pointing to src.
    If symlink creation fails (e.g. on Windows, where creating symlinks may
    require extra privileges), fall back to copying the file.

    Args:
      exp_dir:
        Directory that contains `src` and in which `dst` is created.
      src:
        Name of the existing file, relative to `exp_dir`. It is stored
        verbatim as the link target, so the link stays valid if `exp_dir`
        is moved or renamed as a whole.
      dst:
        Name of the link (or copy) to create, relative to `exp_dir`.
        An existing file or link with this name is replaced.

    Raises:
      OSError:
        If an existing `dst` cannot be removed, or if both the symlink and
        the fallback copy fail.
    """
    exp_dir = Path(exp_dir)
    dst_path = exp_dir / dst

    # Plain path-based calls are used instead of a directory file descriptor:
    # nothing has to be closed on error paths, and the code also runs on
    # platforms without dir_fd support, so the copy fallback stays reachable.
    try:
        os.remove(dst_path)
    except FileNotFoundError:
        # Nothing to replace.
        pass

    try:
        # A relative link target is resolved relative to the directory that
        # contains the link, i.e. exp_dir.
        os.symlink(src, dst_path)
    except OSError:
        copyfile(src=exp_dir / src, dst=dst_path)