mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
support decoding with averaged model when using --iter
This commit is contained in:
parent
ecfb3e9c26
commit
cdd3933ce1
@ -540,23 +540,52 @@ def main():
|
|||||||
model.to(device)
|
model.to(device)
|
||||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||||
else:
|
else:
|
||||||
assert params.iter == 0 and params.avg > 0
|
if params.iter > 0:
|
||||||
start = params.epoch - params.avg
|
filenames = find_checkpoints(
|
||||||
assert start >= 1
|
params.exp_dir, iteration=-params.iter
|
||||||
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
|
)[: params.avg + 1]
|
||||||
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
|
if len(filenames) == 0:
|
||||||
logging.info(
|
raise ValueError(
|
||||||
f"Calculating the averaged model over epoch range from "
|
f"No checkpoints found for"
|
||||||
f"{start} (excluded) to {params.epoch}"
|
f" --iter {params.iter}, --avg {params.avg}"
|
||||||
)
|
)
|
||||||
model.to(device)
|
elif len(filenames) < params.avg + 1:
|
||||||
model.load_state_dict(
|
raise ValueError(
|
||||||
average_checkpoints_with_averaged_model(
|
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||||
filename_start=filename_start,
|
f" --iter {params.iter}, --avg {params.avg}"
|
||||||
filename_end=filename_end,
|
)
|
||||||
device=device,
|
filename_start = filenames[-1]
|
||||||
|
filename_end = filenames[0]
|
||||||
|
logging.info(
|
||||||
|
f"Calculating the averaged model over iteration checkpoints"
|
||||||
|
f" from {filename_start} (excluded) to {filename_end}"
|
||||||
|
)
|
||||||
|
model.to(device)
|
||||||
|
model.load_state_dict(
|
||||||
|
average_checkpoints_with_averaged_model(
|
||||||
|
filename_start=filename_start,
|
||||||
|
filename_end=filename_end,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
assert params.avg > 0
|
||||||
|
start = params.epoch - params.avg
|
||||||
|
assert start >= 1
|
||||||
|
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
|
||||||
|
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
|
||||||
|
logging.info(
|
||||||
|
f"Calculating the averaged model over epoch range from "
|
||||||
|
f"{start} (excluded) to {params.epoch}"
|
||||||
|
)
|
||||||
|
model.to(device)
|
||||||
|
model.load_state_dict(
|
||||||
|
average_checkpoints_with_averaged_model(
|
||||||
|
filename_start=filename_start,
|
||||||
|
filename_end=filename_end,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
)
|
|
||||||
|
|
||||||
model.to(device)
|
model.to(device)
|
||||||
model.eval()
|
model.eval()
|
||||||
|
|||||||
@ -405,7 +405,7 @@ def average_checkpoints_with_averaged_model(
|
|||||||
(3) avg = (model_end + model_start * (weight_start / weight_end))
|
(3) avg = (model_end + model_start * (weight_start / weight_end))
|
||||||
* weight_end
|
* weight_end
|
||||||
|
|
||||||
The model index could be epoch number or checkpoint number.
|
The model index could be epoch number or iteration number.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
filename_start:
|
filename_start:
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user