pruned_transducer_stateless7: use symlinks (when possible) to output best epochs

This commit is contained in:
Peter Ross 2023-06-08 17:03:30 +10:00
parent 6c6ae63821
commit 05906e065b

View File

@ -50,7 +50,6 @@ import copy
import logging import logging
import warnings import warnings
from pathlib import Path from pathlib import Path
from shutil import copyfile
from typing import Any, Dict, Optional, Tuple, Union from typing import Any, Dict, Optional, Tuple, Union
import k2 import k2
@ -89,6 +88,7 @@ from icefall.utils import (
filter_uneven_sized_batch, filter_uneven_sized_batch,
setup_logger, setup_logger,
str2bool, str2bool,
symlink_or_copyfile,
) )
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@ -601,7 +601,8 @@ def save_checkpoint(
""" """
if rank != 0: if rank != 0:
return return
filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" epoch_basename = f"epoch-{params.cur_epoch}.pt"
filename = params.exp_dir / epoch_basename
save_checkpoint_impl( save_checkpoint_impl(
filename=filename, filename=filename,
model=model, model=model,
@ -615,12 +616,14 @@ def save_checkpoint(
) )
if params.best_train_epoch == params.cur_epoch: if params.best_train_epoch == params.cur_epoch:
best_train_filename = params.exp_dir / "best-train-loss.pt" symlink_or_copyfile(
copyfile(src=filename, dst=best_train_filename) exp_dir=params.exp_dir, src=epoch_basename, dst="best-train-loss.pt"
)
if params.best_valid_epoch == params.cur_epoch: if params.best_valid_epoch == params.cur_epoch:
best_valid_filename = params.exp_dir / "best-valid-loss.pt" symlink_or_copyfile(
copyfile(src=filename, dst=best_valid_filename) exp_dir=params.exp_dir, src=epoch_basename, dst="best-valid-loss.pt"
)
def compute_loss( def compute_loss(