Support decoding with averaged model when using --iter (#353)

* support decoding with averaged model when using --iter

* minor fix

* monir fix of copyright date
This commit is contained in:
Zengwei Yao 2022-05-07 13:09:11 +08:00 committed by GitHub
parent f783e10dc8
commit 20f092e709
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 50 additions and 21 deletions

View File

@ -1,6 +1,6 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
# #
# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang, # Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang,
# Zengwei Yao) # Zengwei Yao)
# #
# See ../../../../LICENSE for clarification regarding multiple authors # See ../../../../LICENSE for clarification regarding multiple authors
@ -540,7 +540,36 @@ def main():
model.to(device) model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device)) model.load_state_dict(average_checkpoints(filenames, device=device))
else: else:
assert params.iter == 0 and params.avg > 0 if params.iter > 0:
filenames = find_checkpoints(
params.exp_dir, iteration=-params.iter
)[: params.avg + 1]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg + 1:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
filename_start = filenames[-1]
filename_end = filenames[0]
logging.info(
"Calculating the averaged model over iteration checkpoints"
f" from {filename_start} (excluded) to {filename_end}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
else:
assert params.avg > 0
start = params.epoch - params.avg start = params.epoch - params.avg
assert start >= 1 assert start >= 1
filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_start = f"{params.exp_dir}/epoch-{start}.pt"

View File

@ -1,4 +1,4 @@
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, # Copyright 2021-2022 Xiaomi Corporation (authors: Fangjun Kuang,
# Zengwei Yao) # Zengwei Yao)
# #
# See ../../LICENSE for clarification regarding multiple authors # See ../../LICENSE for clarification regarding multiple authors
@ -405,7 +405,7 @@ def average_checkpoints_with_averaged_model(
(3) avg = (model_end + model_start * (weight_start / weight_end)) (3) avg = (model_end + model_start * (weight_start / weight_end))
* weight_end * weight_end
The model index could be epoch number or checkpoint number. The model index could be epoch number or iteration number.
Args: Args:
filename_start: filename_start: