From b4c38d754745d0127ec728b4ab717c090777b6f3 Mon Sep 17 00:00:00 2001 From: Peter Ross Date: Mon, 12 Jun 2023 15:51:46 +1000 Subject: [PATCH] Use symlinks for best epochs (#1123) * utils: add symlink_or_copyfile * pruned_transducer_stateless7: use symlinks (when possible) to output best epochs * Rename function --------- Co-authored-by: Yifan Yang <64255737+yfyeung@users.noreply.github.com> --- .../ASR/pruned_transducer_stateless7/train.py | 15 +++++++++------ icefall/utils.py | 18 ++++++++++++++++++ 2 files changed, 27 insertions(+), 6 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py index 5ec71ec3f..2b4d51089 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py @@ -50,7 +50,6 @@ import copy import logging import warnings from pathlib import Path -from shutil import copyfile from typing import Any, Dict, Optional, Tuple, Union import k2 @@ -89,6 +88,7 @@ from icefall.utils import ( filter_uneven_sized_batch, setup_logger, str2bool, + symlink_or_copy, ) LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] @@ -601,7 +601,8 @@ def save_checkpoint( """ if rank != 0: return - filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + epoch_basename = f"epoch-{params.cur_epoch}.pt" + filename = params.exp_dir / epoch_basename save_checkpoint_impl( filename=filename, model=model, @@ -615,12 +616,14 @@ def save_checkpoint( ) if params.best_train_epoch == params.cur_epoch: - best_train_filename = params.exp_dir / "best-train-loss.pt" - copyfile(src=filename, dst=best_train_filename) + symlink_or_copy( + exp_dir=params.exp_dir, src=epoch_basename, dst="best-train-loss.pt" + ) if params.best_valid_epoch == params.cur_epoch: - best_valid_filename = params.exp_dir / "best-valid-loss.pt" - copyfile(src=filename, dst=best_valid_filename) + symlink_or_copy( + exp_dir=params.exp_dir, src=epoch_basename, dst="best-valid-loss.pt" + ) def compute_loss( diff --git a/icefall/utils.py b/icefall/utils.py index d002982ec..dfe9a7b42 100644 --- a/icefall/utils.py +++ b/icefall/utils.py @@ -28,6 +28,7 @@ from contextlib import contextmanager from dataclasses import dataclass from datetime import datetime from pathlib import Path +from shutil import copyfile from typing import Dict, Iterable, List, Optional, TextIO, Tuple, Union import k2 @@ -1881,3 +1882,20 @@ def is_cjk(character): ] ] ) + + +def symlink_or_copy(exp_dir: Path, src: str, dst: str): + """ + In the experiment directory, create a symlink pointing to src named dst. + If symlink creation fails (Windows?), fall back to copyfile.""" + + dir_fd = os.open(exp_dir, os.O_RDONLY) + try: + os.remove(dst, dir_fd=dir_fd) + except FileNotFoundError: + pass + try: + os.symlink(src=src, dst=dst, dir_fd=dir_fd) + except OSError: + copyfile(src=exp_dir / src, dst=exp_dir / dst) + os.close(dir_fd)