mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 01:52:41 +00:00
Support decoding with averaged model when using --iter (#353)
* support decoding with averaged model when using --iter * minor fix * monir fix of copyright date
This commit is contained in:
parent
f783e10dc8
commit
20f092e709
@ -1,6 +1,6 @@
|
|||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
#
|
#
|
||||||
# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang,
|
# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang,
|
||||||
# Zengwei Yao)
|
# Zengwei Yao)
|
||||||
#
|
#
|
||||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||||
@ -540,7 +540,36 @@ def main():
|
|||||||
model.to(device)
|
model.to(device)
|
||||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||||
else:
|
else:
|
||||||
assert params.iter == 0 and params.avg > 0
|
if params.iter > 0:
|
||||||
|
filenames = find_checkpoints(
|
||||||
|
params.exp_dir, iteration=-params.iter
|
||||||
|
)[: params.avg + 1]
|
||||||
|
if len(filenames) == 0:
|
||||||
|
raise ValueError(
|
||||||
|
f"No checkpoints found for"
|
||||||
|
f" --iter {params.iter}, --avg {params.avg}"
|
||||||
|
)
|
||||||
|
elif len(filenames) < params.avg + 1:
|
||||||
|
raise ValueError(
|
||||||
|
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||||
|
f" --iter {params.iter}, --avg {params.avg}"
|
||||||
|
)
|
||||||
|
filename_start = filenames[-1]
|
||||||
|
filename_end = filenames[0]
|
||||||
|
logging.info(
|
||||||
|
"Calculating the averaged model over iteration checkpoints"
|
||||||
|
f" from {filename_start} (excluded) to {filename_end}"
|
||||||
|
)
|
||||||
|
model.to(device)
|
||||||
|
model.load_state_dict(
|
||||||
|
average_checkpoints_with_averaged_model(
|
||||||
|
filename_start=filename_start,
|
||||||
|
filename_end=filename_end,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
assert params.avg > 0
|
||||||
start = params.epoch - params.avg
|
start = params.epoch - params.avg
|
||||||
assert start >= 1
|
assert start >= 1
|
||||||
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
|
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang,
|
# Copyright 2021-2022 Xiaomi Corporation (authors: Fangjun Kuang,
|
||||||
# Zengwei Yao)
|
# Zengwei Yao)
|
||||||
#
|
#
|
||||||
# See ../../LICENSE for clarification regarding multiple authors
|
# See ../../LICENSE for clarification regarding multiple authors
|
||||||
@ -405,7 +405,7 @@ def average_checkpoints_with_averaged_model(
|
|||||||
(3) avg = (model_end + model_start * (weight_start / weight_end))
|
(3) avg = (model_end + model_start * (weight_start / weight_end))
|
||||||
* weight_end
|
* weight_end
|
||||||
|
|
||||||
The model index could be epoch number or checkpoint number.
|
The model index could be epoch number or iteration number.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
filename_start:
|
filename_start:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user