diff --git a/egs/libriheavy/LM/zipformer1/evaluate.py b/egs/libriheavy/LM/zipformer1/evaluate.py new file mode 100644 index 000000000..4f51d5d74 --- /dev/null +++ b/egs/libriheavy/LM/zipformer1/evaluate.py @@ -0,0 +1,274 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Xiaoyu Yang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import logging +from pathlib import Path +from train import get_model, get_params, add_model_arguments +from typing import Tuple +import torch + +from lm_datamodule import LmDataset + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) + +from icefall.utils import ( + AttributeDict, + setup_logger, + str2bool, +) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=9, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless7_streaming/exp", + help="The experiment dir", + ) + + add_model_arguments(parser) + + return parser + +def evaluate_dataset( + dl: torch.utils.data.DataLoader, + params: AttributeDict, + model: torch.nn.Module, +) -> Tuple[float, float]: + """Compute the validation loss on a given validation set + + Args: + dl (torch.utils.data.DataLoader): PyTorch's dataloader containing the dataset + params (AttributeDict): It is returned by :func:`get_params`. + model (nn.Module): The neural model + """ + tot_loss = 0 + tot_frames = 0 + num_cuts = 0 + + log_interval = 50 + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" + + device = next(model.parameters()).device + + with torch.set_grad_enabled(False): + for batch_idx, batch in enumerate(dl): + + labels = batch.to(device) # (batch_size, sequence_length) + + loglikes = model(labels) + loss = -loglikes.sum() + + assert loss.requires_grad is False + + num_cuts += labels.size(0) + tot_loss += loss + tot_frames += labels.numel() + + if batch_idx % log_interval == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + + return tot_loss.item(), tot_frames + + +def main(): + parser = get_parser() + + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + params.res_dir = params.exp_dir / "log-evaluation" + + if params.iter > 0: + params.suffix = f"iter-{params.iter}-avg-{params.avg}" + else: + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + + if params.use_averaged_model: + params.suffix += "-use-averaged-model" + + setup_logger(f"{params.res_dir}/log-validation-{params.suffix}") + logging.info(params) + logging.info("Evaluation started") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"Device: {device}") + + model = get_model(params) + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to(device) + model.eval() + + valid = LmDataset(params.valid_file_list, + bytes_per_segment=params.bytes_per_segment) + valid_dl = torch.utils.data.DataLoader( + dataset=valid, + batch_size=params.batch_size, + num_workers=params.num_workers, + drop_last=False) + + logging.info("Evaluation started!") + tot_loss, tot_frames = evaluate_dataset( + dl=valid_dl, + params=params, + model=model, + ) + + logging.info(f"Validation loss: {tot_loss/tot_frames} over {tot_frames} frames.") + logging.info("Finished!") + +if __name__ == "__main__": + main()