diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py index a6fe0336c..e868878e6 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py @@ -18,16 +18,16 @@ """ Usage: (1) greedy search -./pruned_transducer_stateless2/decode.py \ - --epoch 28 \ +./pruned_transducer_stateless4/decode.py \ + --epoch 30 \ --avg 15 \ --exp-dir ./pruned_transducer_stateless2/exp \ --max-duration 100 \ --decoding-method greedy_search (2) beam search -./pruned_transducer_stateless2/decode.py \ - --epoch 28 \ +./pruned_transducer_stateless4/decode.py \ + --epoch 30 \ --avg 15 \ --exp-dir ./pruned_transducer_stateless2/exp \ --max-duration 100 \ @@ -35,8 +35,8 @@ Usage: --beam-size 4 (3) modified beam search -./pruned_transducer_stateless2/decode.py \ - --epoch 28 \ +./pruned_transducer_stateless4/decode.py \ + --epoch 30 \ --avg 15 \ --exp-dir ./pruned_transducer_stateless2/exp \ --max-duration 100 \ @@ -44,8 +44,8 @@ Usage: --beam-size 4 (4) fast beam search -./pruned_transducer_stateless2/decode.py \ - --epoch 28 \ +./pruned_transducer_stateless4/decode.py \ + --epoch 30 \ --avg 15 \ --exp-dir ./pruned_transducer_stateless2/exp \ --max-duration 1500 \ @@ -99,9 +99,9 @@ def get_parser(): parser.add_argument( "--epoch", type=int, - default=28, + default=30, help="""It specifies the checkpoint to use for decoding. - Note: Epoch counts from 0. + Note: Epoch counts from 1. You can specify --avg to use more checkpoints for model averaging.""", ) @@ -128,13 +128,17 @@ def get_parser(): "--use-averaged-model", type=str2bool, default=False, - help="Whether to load averaged model", + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", ) parser.add_argument( "--exp-dir", type=str, - default="pruned_transducer_stateless2/exp", + default="pruned_transducer_stateless4/exp", help="The experiment dir", ) @@ -529,19 +533,20 @@ def main(): start = params.epoch - params.avg + 1 filenames = [] for i in range(start, params.epoch + 1): - if start >= 0: + if i >= 1: filenames.append(f"{params.exp_dir}/epoch-{i}.pt") logging.info(f"averaging {filenames}") model.to(device) model.load_state_dict(average_checkpoints(filenames, device=device)) else: - assert params.iter == 0 + assert params.iter == 0 and params.avg > 0 start = params.epoch - params.avg + assert start >= 1 filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - f"averaging modes over range with {filename_start} (excluded) " - f"and {filename_end}" + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" ) model.to(device) model.load_state_dict(