From 20f092e7098f9809db1ea2ff25a37b17ff4f8237 Mon Sep 17 00:00:00 2001 From: Zengwei Yao Date: Sat, 7 May 2022 13:09:11 +0800 Subject: [PATCH] Support decoding with averaged model when using --iter (#353) * support decoding with averaged model when using --iter * minor fix * monir fix of copyright date --- .../pruned_transducer_stateless4/decode.py | 65 ++++++++++++++----- icefall/checkpoint.py | 6 +- 2 files changed, 50 insertions(+), 21 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py index 025ebd7bc..1f4a22213 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py @@ -1,7 +1,7 @@ #!/usr/bin/env python3 # -# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang, -# Zengwei Yao) +# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang, +# Zengwei Yao) # # See ../../../../LICENSE for clarification regarding multiple authors # @@ -540,23 +540,52 @@ def main(): model.to(device) model.load_state_dict(average_checkpoints(filenames, device=device)) else: - assert params.iter == 0 and params.avg > 0 - start = params.epoch - params.avg - assert start >= 1 - filename_start = f"{params.exp_dir}/epoch-{start}.pt" - filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" - logging.info( - f"Calculating the averaged model over epoch range from " - f"{start} (excluded) to {params.epoch}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, + if params.iter > 0: + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg + 1] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0 + start = params.epoch - params.avg + assert start >= 1 + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) ) - ) model.to(device) model.eval() diff --git a/icefall/checkpoint.py b/icefall/checkpoint.py index ba3823ffc..170586455 100644 --- a/icefall/checkpoint.py +++ b/icefall/checkpoint.py @@ -1,5 +1,5 @@ -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, -# Zengwei Yao) +# Copyright 2021-2022 Xiaomi Corporation (authors: Fangjun Kuang, +# Zengwei Yao) # # See ../../LICENSE for clarification regarding multiple authors # @@ -405,7 +405,7 @@ def average_checkpoints_with_averaged_model( (3) avg = (model_end + model_start * (weight_start / weight_end)) * weight_end - The model index could be epoch number or checkpoint number. + The model index could be epoch number or iteration number. Args: filename_start: