mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 01:52:41 +00:00
Support specifying iteration number of checkpoints for decoding. (#289)
This commit is contained in:
parent
0b6a2213c3
commit
87cf9231ea
@ -98,27 +98,28 @@ def get_parser():
|
|||||||
"--epoch",
|
"--epoch",
|
||||||
type=int,
|
type=int,
|
||||||
default=28,
|
default=28,
|
||||||
help="It specifies the checkpoint to use for decoding."
|
help="""It specifies the checkpoint to use for decoding.
|
||||||
"Note: Epoch counts from 0.",
|
Note: Epoch counts from 0.
|
||||||
|
You can specify --avg to use more checkpoints for model averaging.""",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--iter",
|
||||||
|
type=int,
|
||||||
|
default=0,
|
||||||
|
help="""If positive, --epoch is ignored and it
|
||||||
|
will use the checkpoint exp_dir/checkpoint-iter.pt.
|
||||||
|
You can specify --avg to use more checkpoints for model averaging.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--avg",
|
"--avg",
|
||||||
type=int,
|
type=int,
|
||||||
default=15,
|
default=15,
|
||||||
help="Number of checkpoints to average. Automatically select "
|
help="Number of checkpoints to average. Automatically select "
|
||||||
"consecutive checkpoints before the checkpoint specified by "
|
"consecutive checkpoints before the checkpoint specified by "
|
||||||
"'--epoch'. ",
|
"'--epoch' and '--iter'",
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--avg-last-n",
|
|
||||||
type=int,
|
|
||||||
default=0,
|
|
||||||
help="""If positive, --epoch and --avg are ignored and it
|
|
||||||
will use the last n checkpoints exp_dir/checkpoint-xxx.pt
|
|
||||||
where xxx is the number of processed batches while
|
|
||||||
saving that checkpoint.
|
|
||||||
""",
|
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -453,13 +454,19 @@ def main():
|
|||||||
)
|
)
|
||||||
params.res_dir = params.exp_dir / params.decoding_method
|
params.res_dir = params.exp_dir / params.decoding_method
|
||||||
|
|
||||||
|
if params.iter > 0:
|
||||||
|
params.suffix = f"iter-{params.iter}-avg-{params.avg}"
|
||||||
|
else:
|
||||||
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
|
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
|
||||||
|
|
||||||
if "fast_beam_search" in params.decoding_method:
|
if "fast_beam_search" in params.decoding_method:
|
||||||
params.suffix += f"-beam-{params.beam}"
|
params.suffix += f"-beam-{params.beam}"
|
||||||
params.suffix += f"-max-contexts-{params.max_contexts}"
|
params.suffix += f"-max-contexts-{params.max_contexts}"
|
||||||
params.suffix += f"-max-states-{params.max_states}"
|
params.suffix += f"-max-states-{params.max_states}"
|
||||||
elif "beam_search" in params.decoding_method:
|
elif "beam_search" in params.decoding_method:
|
||||||
params.suffix += f"-beam-{params.beam_size}"
|
params.suffix += (
|
||||||
|
f"-{params.decoding_method}-beam-size-{params.beam_size}"
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
params.suffix += f"-context-{params.context_size}"
|
params.suffix += f"-context-{params.context_size}"
|
||||||
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
|
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
|
||||||
@ -485,8 +492,20 @@ def main():
|
|||||||
logging.info("About to create model")
|
logging.info("About to create model")
|
||||||
model = get_transducer_model(params)
|
model = get_transducer_model(params)
|
||||||
|
|
||||||
if params.avg_last_n > 0:
|
if params.iter > 0:
|
||||||
filenames = find_checkpoints(params.exp_dir)[: params.avg_last_n]
|
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||||
|
: params.avg
|
||||||
|
]
|
||||||
|
if len(filenames) == 0:
|
||||||
|
raise ValueError(
|
||||||
|
f"No checkpoints found for"
|
||||||
|
f" --iter {params.iter}, --avg {params.avg}"
|
||||||
|
)
|
||||||
|
elif len(filenames) < params.avg:
|
||||||
|
raise ValueError(
|
||||||
|
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||||
|
f" --iter {params.iter}, --avg {params.avg}"
|
||||||
|
)
|
||||||
logging.info(f"averaging {filenames}")
|
logging.info(f"averaging {filenames}")
|
||||||
model.to(device)
|
model.to(device)
|
||||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||||
|
@ -216,27 +216,62 @@ def save_checkpoint_with_global_batch_idx(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def find_checkpoints(out_dir: Path) -> List[str]:
|
def find_checkpoints(out_dir: Path, iteration: int = 0) -> List[str]:
|
||||||
"""Find all available checkpoints in a directory.
|
"""Find all available checkpoints in a directory.
|
||||||
|
|
||||||
The checkpoint filenames have the form: `checkpoint-xxx.pt`
|
The checkpoint filenames have the form: `checkpoint-xxx.pt`
|
||||||
where xxx is a numerical value.
|
where xxx is a numerical value.
|
||||||
|
|
||||||
|
Assume you have the following checkpoints in the folder `foo`:
|
||||||
|
|
||||||
|
- checkpoint-1.pt
|
||||||
|
- checkpoint-20.pt
|
||||||
|
- checkpoint-300.pt
|
||||||
|
- checkpoint-4000.pt
|
||||||
|
|
||||||
|
Case 1 (Return all checkpoints)::
|
||||||
|
|
||||||
|
find_checkpoints(out_dir='foo')
|
||||||
|
|
||||||
|
Case 2 (Return checkpoints newer than checkpoint-20.pt, i.e.,
|
||||||
|
checkpoint-4000.pt, checkpoint-300.pt, and checkpoint-20.pt)
|
||||||
|
|
||||||
|
find_checkpoints(out_dir='foo', iteration=20)
|
||||||
|
|
||||||
|
Case 3 (Return checkpoints older than checkpoint-20.pt, i.e.,
|
||||||
|
checkpoint-20.pt, checkpoint-1.pt)::
|
||||||
|
|
||||||
|
find_checkpoints(out_dir='foo', iteration=-20)
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
out_dir:
|
out_dir:
|
||||||
The directory where to search for checkpoints.
|
The directory where to search for checkpoints.
|
||||||
|
iteration:
|
||||||
|
If it is 0, return all available checkpoints.
|
||||||
|
If it is positive, return the checkpoints whose iteration number is
|
||||||
|
greater than or equal to `iteration`.
|
||||||
|
If it is negative, return the checkpoints whose iteration number is
|
||||||
|
less than or equal to `-iteration`.
|
||||||
Returns:
|
Returns:
|
||||||
Return a list of checkpoint filenames, sorted in descending
|
Return a list of checkpoint filenames, sorted in descending
|
||||||
order by the numerical value in the filename.
|
order by the numerical value in the filename.
|
||||||
"""
|
"""
|
||||||
checkpoints = list(glob.glob(f"{out_dir}/checkpoint-[0-9]*.pt"))
|
checkpoints = list(glob.glob(f"{out_dir}/checkpoint-[0-9]*.pt"))
|
||||||
pattern = re.compile(r"checkpoint-([0-9]+).pt")
|
pattern = re.compile(r"checkpoint-([0-9]+).pt")
|
||||||
idx_checkpoints = [
|
iter_checkpoints = [
|
||||||
(int(pattern.search(c).group(1)), c) for c in checkpoints
|
(int(pattern.search(c).group(1)), c) for c in checkpoints
|
||||||
]
|
]
|
||||||
|
# iter_checkpoints is a list of tuples. Each tuple contains
|
||||||
|
# two elements: (iteration_number, checkpoint-iteration_number.pt)
|
||||||
|
|
||||||
|
iter_checkpoints = sorted(
|
||||||
|
iter_checkpoints, reverse=True, key=lambda x: x[0]
|
||||||
|
)
|
||||||
|
if iteration >= 0:
|
||||||
|
ans = [ic[1] for ic in iter_checkpoints if ic[0] >= iteration]
|
||||||
|
else:
|
||||||
|
ans = [ic[1] for ic in iter_checkpoints if ic[0] <= -iteration]
|
||||||
|
|
||||||
idx_checkpoints = sorted(idx_checkpoints, reverse=True, key=lambda x: x[0])
|
|
||||||
ans = [ic[1] for ic in idx_checkpoints]
|
|
||||||
return ans
|
return ans
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user