mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-08 17:42:21 +00:00
Use symlinks for best epochs (#1123)
* utils: add symlink_or_copyfile * pruned_transducer_stateless7: use symlinks (when possible) to output best epochs * Rename function --------- Co-authored-by: Yifan Yang <64255737+yfyeung@users.noreply.github.com>
This commit is contained in:
parent
dca21c2a17
commit
b4c38d7547
@ -50,7 +50,6 @@ import copy
|
|||||||
import logging
|
import logging
|
||||||
import warnings
|
import warnings
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from shutil import copyfile
|
|
||||||
from typing import Any, Dict, Optional, Tuple, Union
|
from typing import Any, Dict, Optional, Tuple, Union
|
||||||
|
|
||||||
import k2
|
import k2
|
||||||
@ -89,6 +88,7 @@ from icefall.utils import (
|
|||||||
filter_uneven_sized_batch,
|
filter_uneven_sized_batch,
|
||||||
setup_logger,
|
setup_logger,
|
||||||
str2bool,
|
str2bool,
|
||||||
|
symlink_or_copy,
|
||||||
)
|
)
|
||||||
|
|
||||||
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
|
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
|
||||||
@ -601,7 +601,8 @@ def save_checkpoint(
|
|||||||
"""
|
"""
|
||||||
if rank != 0:
|
if rank != 0:
|
||||||
return
|
return
|
||||||
filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt"
|
epoch_basename = f"epoch-{params.cur_epoch}.pt"
|
||||||
|
filename = params.exp_dir / epoch_basename
|
||||||
save_checkpoint_impl(
|
save_checkpoint_impl(
|
||||||
filename=filename,
|
filename=filename,
|
||||||
model=model,
|
model=model,
|
||||||
@ -615,12 +616,14 @@ def save_checkpoint(
|
|||||||
)
|
)
|
||||||
|
|
||||||
if params.best_train_epoch == params.cur_epoch:
|
if params.best_train_epoch == params.cur_epoch:
|
||||||
best_train_filename = params.exp_dir / "best-train-loss.pt"
|
symlink_or_copy(
|
||||||
copyfile(src=filename, dst=best_train_filename)
|
exp_dir=params.exp_dir, src=epoch_basename, dst="best-train-loss.pt"
|
||||||
|
)
|
||||||
|
|
||||||
if params.best_valid_epoch == params.cur_epoch:
|
if params.best_valid_epoch == params.cur_epoch:
|
||||||
best_valid_filename = params.exp_dir / "best-valid-loss.pt"
|
symlink_or_copy(
|
||||||
copyfile(src=filename, dst=best_valid_filename)
|
exp_dir=params.exp_dir, src=epoch_basename, dst="best-valid-loss.pt"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def compute_loss(
|
def compute_loss(
|
||||||
|
@ -28,6 +28,7 @@ from contextlib import contextmanager
|
|||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from shutil import copyfile
|
||||||
from typing import Dict, Iterable, List, Optional, TextIO, Tuple, Union
|
from typing import Dict, Iterable, List, Optional, TextIO, Tuple, Union
|
||||||
|
|
||||||
import k2
|
import k2
|
||||||
@ -1881,3 +1882,20 @@ def is_cjk(character):
|
|||||||
]
|
]
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def symlink_or_copy(exp_dir: Path, src: str, dst: str):
|
||||||
|
"""
|
||||||
|
In the experiment directory, create a symlink pointing to src named dst.
|
||||||
|
If symlink creation fails (Windows?), fall back to copyfile."""
|
||||||
|
|
||||||
|
dir_fd = os.open(exp_dir, os.O_RDONLY)
|
||||||
|
try:
|
||||||
|
os.remove(dst, dir_fd=dir_fd)
|
||||||
|
except FileNotFoundError:
|
||||||
|
pass
|
||||||
|
try:
|
||||||
|
os.symlink(src=src, dst=dst, dir_fd=dir_fd)
|
||||||
|
except OSError:
|
||||||
|
copyfile(src=exp_dir / src, dst=exp_dir / dst)
|
||||||
|
os.close(dir_fd)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user