Fix exception in find_checkpoints (#668)

This commit is contained in:
abb128 2022-11-26 04:10:37 +02:00 committed by GitHub
parent db75627e92
commit 61032e70e0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -292,7 +292,15 @@ def find_checkpoints(out_dir: Path, iteration: int = 0) -> List[str]:
""" """
checkpoints = list(glob.glob(f"{out_dir}/checkpoint-[0-9]*.pt")) checkpoints = list(glob.glob(f"{out_dir}/checkpoint-[0-9]*.pt"))
pattern = re.compile(r"checkpoint-([0-9]+).pt") pattern = re.compile(r"checkpoint-([0-9]+).pt")
iter_checkpoints = [(int(pattern.search(c).group(1)), c) for c in checkpoints] iter_checkpoints = []
for c in checkpoints:
result = pattern.search(c)
if not result:
logging.warn(f"Invalid checkpoint filename {c}")
continue
iter_checkpoints.append((int(result.group(1)), c))
# iter_checkpoints is a list of tuples. Each tuple contains # iter_checkpoints is a list of tuples. Each tuple contains
# two elements: (iteration_number, checkpoint-iteration_number.pt) # two elements: (iteration_number, checkpoint-iteration_number.pt)