From fab0258df59c7f300f711410c89e727cb5e344d2 Mon Sep 17 00:00:00 2001 From: Yifan Yang Date: Wed, 31 May 2023 10:41:11 +0800 Subject: [PATCH] Add averaged model to rnnlm decoding --- icefall/rnn_lm/compute_perplexity.py | 141 ++++++++++++++++++++------- 1 file changed, 104 insertions(+), 37 deletions(-) diff --git a/icefall/rnn_lm/compute_perplexity.py b/icefall/rnn_lm/compute_perplexity.py index cc566bd92..4eadafb27 100755 --- a/icefall/rnn_lm/compute_perplexity.py +++ b/icefall/rnn_lm/compute_perplexity.py @@ -20,8 +20,8 @@ Usage: ./rnn_lm/compute_perplexity.py \ --epoch 4 \ --avg 2 \ + --use-averaged-model 1 \ --lm-data ./data/lm_training_bpe_500/sorted_lm_data-test.pt - """ import argparse @@ -33,7 +33,12 @@ import torch from dataset import get_dataloader from model import RnnLmModel -from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) from icefall.utils import AttributeDict, setup_logger, str2bool @@ -69,6 +74,17 @@ def get_parser(): """, ) + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=False, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + parser.add_argument( "--exp-dir", type=str, @@ -122,14 +138,14 @@ def get_parser(): parser.add_argument( "--batch-size", type=int, - default=50, + default=150, help="Number of RNN layers the model", ) parser.add_argument( "--max-sent-len", type=int, - default=100, + default=200, help="Number of RNN layers the model", ) @@ -153,6 +169,7 @@ def get_parser(): default=0, help="Blank ID", ) + return parser @@ -165,13 +182,18 @@ def main(): params = AttributeDict(vars(args)) + if params.use_averaged_model: + params.suffix = "-use-averaged-model" + else: + params.suffix = "" + if params.iter > 0: setup_logger( - f"{params.exp_dir}/log-ppl/log-ppl-iter-{params.iter}-avg-{params.avg}" + f"{params.exp_dir}/log-ppl/log-ppl-iter-{params.iter}-avg-{params.avg}{params.suffix}" ) else: setup_logger( - f"{params.exp_dir}/log-ppl/log-ppl-epoch-{params.epoch}-avg-{params.avg}" + f"{params.exp_dir}/log-ppl/log-ppl-epoch-{params.epoch}-avg-{params.avg}{params.suffix}" ) logging.info("Computing perplexity started") logging.info(params) @@ -191,37 +213,82 @@ def main(): tie_weights=params.tie_weights, ) - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict( - average_checkpoints(filenames, device=device), strict=False - ) - elif params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if i >= 0: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict( - average_checkpoints(filenames, device=device), strict=False - ) + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) model.to(device) model.eval() @@ -263,7 +330,7 @@ def main(): ppl = math.exp(tot_loss / num_tokens) logging.info( f"total nll: {tot_loss}, num tokens: {num_tokens}, " - f"num sentences: {num_sentences}, ppl: {ppl:.3f}" + f"num sentences: {num_sentences}, ppl: {ppl:.3f}, " )