From aea8a03e009122f371f2c696c7baf228276ff076 Mon Sep 17 00:00:00 2001 From: yaozengwei Date: Mon, 2 May 2022 12:17:43 +0800 Subject: [PATCH] update decode file --- .../pruned_transducer_stateless3/decode.py | 29 +++++++++---------- icefall/checkpoint.py | 14 ++------- 2 files changed, 16 insertions(+), 27 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py index 34125e9d6..a6fe0336c 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py @@ -536,22 +536,21 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: assert params.iter == 0 - if True: - start = params.epoch - params.avg - filename_start = f"{params.exp_dir}/epoch-{start}.pt" - filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" - logging.info( - f"averaging modes over range with {filename_start} (excluded) " - f"and {filename_end}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) + start = params.epoch - params.avg + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"averaging modes over range with {filename_start} (excluded) " + f"and {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, ) + ) model.to(device) model.eval() diff --git a/icefall/checkpoint.py b/icefall/checkpoint.py index af8c1701d..c7b09c8ac 100644 --- a/icefall/checkpoint.py +++ b/icefall/checkpoint.py @@ -416,8 +416,9 @@ def average_checkpoints_with_averaged_model( weight_start = -batch_idx_train_start / interval weight_end = batch_idx_train_end / interval - avg = state_dict_end["model_avg"] + model_end = state_dict_end["model_avg"] model_start = state_dict_start["model_avg"] + avg = model_end # Identify shared parameters. Two parameters are said to be shared # if they have the same data_ptr @@ -434,14 +435,3 @@ def average_checkpoints_with_averaged_model( avg[k] += model_start[k] * weight_start return avg - - -def load_checkpoint_with_averaged_model( - filename: str, - model: nn.Module, - strict: bool = True, -) -> None: - """Load checkpoint with aaveraged model.""" - logging.info(f"Loading checkpoint from {filename}, using averaged model") - checkpoint = torch.load(filename, map_location="cpu") - model.load_state_dict(checkpoint["model_avg"], strict=strict)