mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
pruned_transducer_stateless7: use symlinks (when possible) to output best epochs
This commit is contained in:
parent
6c6ae63821
commit
05906e065b
@ -50,7 +50,6 @@ import copy
|
||||
import logging
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
from shutil import copyfile
|
||||
from typing import Any, Dict, Optional, Tuple, Union
|
||||
|
||||
import k2
|
||||
@ -89,6 +88,7 @@ from icefall.utils import (
|
||||
filter_uneven_sized_batch,
|
||||
setup_logger,
|
||||
str2bool,
|
||||
symlink_or_copyfile,
|
||||
)
|
||||
|
||||
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
|
||||
@ -601,7 +601,8 @@ def save_checkpoint(
|
||||
"""
|
||||
if rank != 0:
|
||||
return
|
||||
filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt"
|
||||
epoch_basename = f"epoch-{params.cur_epoch}.pt"
|
||||
filename = params.exp_dir / epoch_basename
|
||||
save_checkpoint_impl(
|
||||
filename=filename,
|
||||
model=model,
|
||||
@ -615,12 +616,14 @@ def save_checkpoint(
|
||||
)
|
||||
|
||||
if params.best_train_epoch == params.cur_epoch:
|
||||
best_train_filename = params.exp_dir / "best-train-loss.pt"
|
||||
copyfile(src=filename, dst=best_train_filename)
|
||||
symlink_or_copyfile(
|
||||
exp_dir=params.exp_dir, src=epoch_basename, dst="best-train-loss.pt"
|
||||
)
|
||||
|
||||
if params.best_valid_epoch == params.cur_epoch:
|
||||
best_valid_filename = params.exp_dir / "best-valid-loss.pt"
|
||||
copyfile(src=filename, dst=best_valid_filename)
|
||||
symlink_or_copyfile(
|
||||
exp_dir=params.exp_dir, src=epoch_basename, dst="best-valid-loss.pt"
|
||||
)
|
||||
|
||||
|
||||
def compute_loss(
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user