From 05906e065b62fffe4ede195a531e8f8c9a18f1c3 Mon Sep 17 00:00:00 2001 From: Peter Ross Date: Thu, 8 Jun 2023 17:03:30 +1000 Subject: [PATCH] pruned_transducer_stateless7: use symlinks (when possible) to output best epochs --- .../ASR/pruned_transducer_stateless7/train.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py index 5ec71ec3f..c08da6733 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py @@ -50,7 +50,6 @@ import copy import logging import warnings from pathlib import Path -from shutil import copyfile from typing import Any, Dict, Optional, Tuple, Union import k2 @@ -89,6 +88,7 @@ from icefall.utils import ( filter_uneven_sized_batch, setup_logger, str2bool, + symlink_or_copyfile, ) LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] @@ -601,7 +601,8 @@ def save_checkpoint( """ if rank != 0: return - filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + epoch_basename = f"epoch-{params.cur_epoch}.pt" + filename = params.exp_dir / epoch_basename save_checkpoint_impl( filename=filename, model=model, @@ -615,12 +616,14 @@ def save_checkpoint( ) if params.best_train_epoch == params.cur_epoch: - best_train_filename = params.exp_dir / "best-train-loss.pt" - copyfile(src=filename, dst=best_train_filename) + symlink_or_copyfile( + exp_dir=params.exp_dir, src=epoch_basename, dst="best-train-loss.pt" + ) if params.best_valid_epoch == params.cur_epoch: - best_valid_filename = params.exp_dir / "best-valid-loss.pt" - copyfile(src=filename, dst=best_valid_filename) + symlink_or_copyfile( + exp_dir=params.exp_dir, src=epoch_basename, dst="best-valid-loss.pt" + ) def compute_loss(