From 08b37e07a4bf10d2098d1055acf4a3304e437b75 Mon Sep 17 00:00:00 2001 From: yaozengwei Date: Mon, 2 May 2022 00:50:32 +0800 Subject: [PATCH] minor fix --- .../ASR/pruned_transducer_stateless3/decode.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py index 016393215..34125e9d6 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py @@ -81,7 +81,6 @@ from icefall.checkpoint import ( average_checkpoints_with_averaged_model, find_checkpoints, load_checkpoint, - load_checkpoint_with_averaged_model, ) from icefall.utils import ( AttributeDict, @@ -481,6 +480,9 @@ def main(): params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" + if params.use_averaged_model: + params.suffix += "-use-averaged-model" + setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") logging.info("Decoding started") @@ -534,15 +536,14 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: assert params.iter == 0 - if params.avg == 1: - filename = f"{params.exp_dir}/epoch-{params.epoch}.pt" - load_checkpoint_with_averaged_model(filename, model) - else: - assert params.avg > 1 + if True: start = params.epoch - params.avg filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" - logging.info(f"averaging {filename_start} and {filename_end}") + logging.info( + f"averaging modes over range with {filename_start} (excluded) " + f"and {filename_end}" + ) model.to(device) model.load_state_dict( average_checkpoints_with_averaged_model(