diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py index 5ec71ec3f..2b4d51089 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py @@ -50,7 +50,6 @@ import copy import logging import warnings from pathlib import Path -from shutil import copyfile from typing import Any, Dict, Optional, Tuple, Union import k2 @@ -89,6 +88,7 @@ from icefall.utils import ( filter_uneven_sized_batch, setup_logger, str2bool, + symlink_or_copy, ) LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] @@ -601,7 +601,8 @@ def save_checkpoint( """ if rank != 0: return - filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + epoch_basename = f"epoch-{params.cur_epoch}.pt" + filename = params.exp_dir / epoch_basename save_checkpoint_impl( filename=filename, model=model, @@ -615,12 +616,14 @@ def save_checkpoint( ) if params.best_train_epoch == params.cur_epoch: - best_train_filename = params.exp_dir / "best-train-loss.pt" - copyfile(src=filename, dst=best_train_filename) + symlink_or_copy( + exp_dir=params.exp_dir, src=epoch_basename, dst="best-train-loss.pt" + ) if params.best_valid_epoch == params.cur_epoch: - best_valid_filename = params.exp_dir / "best-valid-loss.pt" - copyfile(src=filename, dst=best_valid_filename) + symlink_or_copy( + exp_dir=params.exp_dir, src=epoch_basename, dst="best-valid-loss.pt" + ) def compute_loss( diff --git a/icefall/utils.py b/icefall/utils.py index d002982ec..dfe9a7b42 100644 --- a/icefall/utils.py +++ b/icefall/utils.py @@ -28,6 +28,7 @@ from contextlib import contextmanager from dataclasses import dataclass from datetime import datetime from pathlib import Path +from shutil import copyfile from typing import Dict, Iterable, List, Optional, TextIO, Tuple, Union import k2 @@ -1881,3 +1882,20 @@ def is_cjk(character): ] ] ) + + +def symlink_or_copy(exp_dir: Path, src: str, dst: str): + """ + In the experiment directory, create a symlink pointing to src named dst. + If symlink creation fails (Windows?), fall back to copyfile.""" + + dir_fd = os.open(exp_dir, os.O_RDONLY) + try: + os.remove(dst, dir_fd=dir_fd) + except FileNotFoundError: + pass + try: + os.symlink(src=src, dst=dst, dir_fd=dir_fd) + except OSError: + copyfile(src=exp_dir / src, dst=exp_dir / dst) + os.close(dir_fd)