mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 01:52:41 +00:00
Support specifying iteration number of checkpoints for decoding. (#289)
This commit is contained in:
parent
0b6a2213c3
commit
87cf9231ea
@ -98,27 +98,28 @@ def get_parser():
|
||||
"--epoch",
|
||||
type=int,
|
||||
default=28,
|
||||
help="It specifies the checkpoint to use for decoding."
|
||||
"Note: Epoch counts from 0.",
|
||||
help="""It specifies the checkpoint to use for decoding.
|
||||
Note: Epoch counts from 0.
|
||||
You can specify --avg to use more checkpoints for model averaging.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--iter",
|
||||
type=int,
|
||||
default=0,
|
||||
help="""If positive, --epoch is ignored and it
|
||||
will use the checkpoint exp_dir/checkpoint-iter.pt.
|
||||
You can specify --avg to use more checkpoints for model averaging.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--avg",
|
||||
type=int,
|
||||
default=15,
|
||||
help="Number of checkpoints to average. Automatically select "
|
||||
"consecutive checkpoints before the checkpoint specified by "
|
||||
"'--epoch'. ",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--avg-last-n",
|
||||
type=int,
|
||||
default=0,
|
||||
help="""If positive, --epoch and --avg are ignored and it
|
||||
will use the last n checkpoints exp_dir/checkpoint-xxx.pt
|
||||
where xxx is the number of processed batches while
|
||||
saving that checkpoint.
|
||||
""",
|
||||
"'--epoch' and '--iter'",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
@ -453,13 +454,19 @@ def main():
|
||||
)
|
||||
params.res_dir = params.exp_dir / params.decoding_method
|
||||
|
||||
if params.iter > 0:
|
||||
params.suffix = f"iter-{params.iter}-avg-{params.avg}"
|
||||
else:
|
||||
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
|
||||
|
||||
if "fast_beam_search" in params.decoding_method:
|
||||
params.suffix += f"-beam-{params.beam}"
|
||||
params.suffix += f"-max-contexts-{params.max_contexts}"
|
||||
params.suffix += f"-max-states-{params.max_states}"
|
||||
elif "beam_search" in params.decoding_method:
|
||||
params.suffix += f"-beam-{params.beam_size}"
|
||||
params.suffix += (
|
||||
f"-{params.decoding_method}-beam-size-{params.beam_size}"
|
||||
)
|
||||
else:
|
||||
params.suffix += f"-context-{params.context_size}"
|
||||
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
|
||||
@ -485,8 +492,20 @@ def main():
|
||||
logging.info("About to create model")
|
||||
model = get_transducer_model(params)
|
||||
|
||||
if params.avg_last_n > 0:
|
||||
filenames = find_checkpoints(params.exp_dir)[: params.avg_last_n]
|
||||
if params.iter > 0:
|
||||
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||
: params.avg
|
||||
]
|
||||
if len(filenames) == 0:
|
||||
raise ValueError(
|
||||
f"No checkpoints found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
elif len(filenames) < params.avg:
|
||||
raise ValueError(
|
||||
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
logging.info(f"averaging {filenames}")
|
||||
model.to(device)
|
||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||
|
@ -216,27 +216,62 @@ def save_checkpoint_with_global_batch_idx(
|
||||
)
|
||||
|
||||
|
||||
def find_checkpoints(out_dir: Path) -> List[str]:
|
||||
def find_checkpoints(out_dir: Path, iteration: int = 0) -> List[str]:
|
||||
"""Find all available checkpoints in a directory.
|
||||
|
||||
The checkpoint filenames have the form: `checkpoint-xxx.pt`
|
||||
where xxx is a numerical value.
|
||||
|
||||
Assume you have the following checkpoints in the folder `foo`:
|
||||
|
||||
- checkpoint-1.pt
|
||||
- checkpoint-20.pt
|
||||
- checkpoint-300.pt
|
||||
- checkpoint-4000.pt
|
||||
|
||||
Case 1 (Return all checkpoints)::
|
||||
|
||||
find_checkpoints(out_dir='foo')
|
||||
|
||||
Case 2 (Return checkpoints newer than checkpoint-20.pt, i.e.,
|
||||
checkpoint-4000.pt, checkpoint-300.pt, and checkpoint-20.pt)
|
||||
|
||||
find_checkpoints(out_dir='foo', iteration=20)
|
||||
|
||||
Case 3 (Return checkpoints older than checkpoint-20.pt, i.e.,
|
||||
checkpoint-20.pt, checkpoint-1.pt)::
|
||||
|
||||
find_checkpoints(out_dir='foo', iteration=-20)
|
||||
|
||||
Args:
|
||||
out_dir:
|
||||
The directory where to search for checkpoints.
|
||||
iteration:
|
||||
If it is 0, return all available checkpoints.
|
||||
If it is positive, return the checkpoints whose iteration number is
|
||||
greater than or equal to `iteration`.
|
||||
If it is negative, return the checkpoints whose iteration number is
|
||||
less than or equal to `-iteration`.
|
||||
Returns:
|
||||
Return a list of checkpoint filenames, sorted in descending
|
||||
order by the numerical value in the filename.
|
||||
"""
|
||||
checkpoints = list(glob.glob(f"{out_dir}/checkpoint-[0-9]*.pt"))
|
||||
pattern = re.compile(r"checkpoint-([0-9]+).pt")
|
||||
idx_checkpoints = [
|
||||
iter_checkpoints = [
|
||||
(int(pattern.search(c).group(1)), c) for c in checkpoints
|
||||
]
|
||||
# iter_checkpoints is a list of tuples. Each tuple contains
|
||||
# two elements: (iteration_number, checkpoint-iteration_number.pt)
|
||||
|
||||
iter_checkpoints = sorted(
|
||||
iter_checkpoints, reverse=True, key=lambda x: x[0]
|
||||
)
|
||||
if iteration >= 0:
|
||||
ans = [ic[1] for ic in iter_checkpoints if ic[0] >= iteration]
|
||||
else:
|
||||
ans = [ic[1] for ic in iter_checkpoints if ic[0] <= -iteration]
|
||||
|
||||
idx_checkpoints = sorted(idx_checkpoints, reverse=True, key=lambda x: x[0])
|
||||
ans = [ic[1] for ic in idx_checkpoints]
|
||||
return ans
|
||||
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user