#!/usr/bin/env python3
#
# Copyright    2021-2023  Xiaomi Corporation     (Author: Fangjun Kuang,
#                                                         Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Report top-1 / top-5 accuracy of a trained checkpoint on the validation set.

The script builds the model with ``get_model`` from ``train.py``, restores its
weights from one checkpoint, from the plain average of several checkpoints, or
from the running averaged model stored in two checkpoints, and then runs it
over the loader returned by ``ImageNetClsDataModule.build_val_loader``.
"""


import argparse
import logging
import time
from pathlib import Path

import torch
import torch.nn as nn

from cls_datamodule import ImageNetClsDataModule
from train import add_model_arguments, get_params, get_model
from icefall.checkpoint import (
    average_checkpoints,
    average_checkpoints_with_averaged_model,
    load_checkpoint,
)
from icefall.utils import (
    AttributeDict,
    setup_logger,
    str2bool,
)
from utils import AverageMeter, accuracy


def get_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for this script.

    Only the checkpoint-selection options and the model options (added by
    ``add_model_arguments``) are registered here; ``main`` adds the data
    module options on top.

    Returns:
      The configured parser.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--epoch",
        type=int,
        default=30,
        help="""It specifies the checkpoint to use for decoding.
        Note: Epoch counts from 1.
        You can specify --avg to use more checkpoints for model averaging.""",
    )

    # The help text used to mention '--iter' as well, but this script neither
    # defines nor reads such an option.
    parser.add_argument(
        "--avg",
        type=int,
        default=15,
        help="Number of checkpoints to average. Automatically select "
        "consecutive checkpoints before the checkpoint specified by "
        "'--epoch'",
    )

    parser.add_argument(
        "--use-averaged-model",
        type=str2bool,
        default=True,
        help="Whether to load averaged model. Currently it only supports "
        "using --epoch. If True, it would decode with the averaged model "
        "over the epoch range from `epoch-avg` (excluded) to `epoch`. "
        "Actually only the models with epoch number of `epoch-avg` and "
        "`epoch` are loaded for averaging. ",
    )

    # NOTE(review): this default looks copied from another recipe; this file
    # lives under swin_transformer/ -- confirm the intended default.
    parser.add_argument(
        "--exp-dir",
        type=str,
        default="zipformer/exp",
        help="The experiment dir",
    )

    add_model_arguments(parser)

    return parser


def _model_device(model: nn.Module) -> torch.device:
    """Return the device that input batches should be moved to for `model`."""
    first_param = next(model.parameters(), None)
    if first_param is not None:
        return first_param.device
    # A parameter-free model gives no hint; prefer CUDA when it exists, which
    # matches what this script did before it became device-aware.
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")


@torch.no_grad()
def validate(
    params: AttributeDict,
    model: nn.Module,
    valid_dl: torch.utils.data.DataLoader,
) -> None:
    """Run the validation process and log the resulting accuracies.

    Gradient tracking is disabled for the whole call.

    Args:
      params:
        Run configuration. Not read by this function; kept so that the
        signature stays stable for callers.
      model:
        The model to evaluate. It is called as ``model(images)``. Switching
        it to eval mode is the caller's job.
      valid_dl:
        Loader yielding ``(images, targets)`` pairs.
    """
    batch_time = AverageMeter()
    acc1_meter = AverageMeter()
    acc5_meter = AverageMeter()

    # Follow the model instead of calling `.cuda()` unconditionally, so the
    # CPU fallback chosen in `main` works on hosts without a GPU.
    device = _model_device(model)

    end = time.time()
    for images, targets in valid_dl:
        images = images.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)

        # compute outputs
        outputs = model(images)

        # measure accuracy; each batch is weighted by its number of targets
        acc1, acc5 = accuracy(outputs, targets, topk=(1, 5))
        batch_size = targets.size(0)
        acc1_meter.update(acc1.item(), batch_size)
        acc5_meter.update(acc5.item(), batch_size)

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

    logging.info(f" * Acc@1 {acc1_meter.avg:.3f} Acc@5 {acc5_meter.avg:.3f}")


@torch.no_grad()
def main() -> None:
    """Parse the arguments, restore the model weights and validate the model.

    Raises:
      ValueError: If ``--avg`` is not positive, or if ``--epoch`` / ``--avg``
        select no usable checkpoint.
    """
    parser = get_parser()
    ImageNetClsDataModule.add_arguments(parser)
    args = parser.parse_args()
    args.exp_dir = Path(args.exp_dir)

    params = get_params()
    params.update(vars(args))

    # Explicit exceptions rather than `assert`: asserts vanish under `-O`,
    # and a non-positive --avg is a usage error on either loading path.
    if params.avg <= 0:
        raise ValueError(f"--avg must be a positive integer, got {params.avg}")

    params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"

    if params.use_averaged_model:
        params.suffix += "-use-averaged-model"

    setup_logger(f"{params.exp_dir}/log-decode-{params.suffix}")
    logging.info("Validation started")

    device = torch.device("cpu")
    if torch.cuda.is_available():
        device = torch.device("cuda", 0)

    logging.info(f"Device: {device}")

    logging.info(params)

    logging.info("About to create model")
    model = get_model(params)

    if not params.use_averaged_model:
        if params.avg == 1:
            load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
        else:
            # Plain average of the last `avg` epoch checkpoints; epochs count
            # from 1, so indices below 1 are skipped.
            start = params.epoch - params.avg + 1
            filenames = [
                f"{params.exp_dir}/epoch-{i}.pt"
                for i in range(start, params.epoch + 1)
                if i >= 1
            ]
            if not filenames:
                raise ValueError(
                    f"No checkpoint selected by --epoch {params.epoch} "
                    f"and --avg {params.avg}"
                )
            logging.info(f"averaging {filenames}")
            model.to(device)
            model.load_state_dict(average_checkpoints(filenames, device=device))
    else:
        # Only the two end-point checkpoints are handed to the averaging
        # helper, so epoch `start` itself must exist.
        start = params.epoch - params.avg
        if start < 1:
            raise ValueError(
                "--epoch minus --avg must be at least 1 when "
                f"--use-averaged-model is true, got {start}"
            )
        filename_start = f"{params.exp_dir}/epoch-{start}.pt"
        filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
        logging.info(
            f"Calculating the averaged model over epoch range from "
            f"{start} (excluded) to {params.epoch}"
        )
        model.to(device)
        model.load_state_dict(
            average_checkpoints_with_averaged_model(
                filename_start=filename_start,
                filename_end=filename_end,
                device=device,
            )
        )

    # Needed for the single-checkpoint path, which loads before moving.
    model.to(device)
    model.eval()

    num_param = sum(p.numel() for p in model.parameters())
    logging.info(f"Number of model parameters: {num_param}")

    # Create datasets and dataloaders
    imagenet = ImageNetClsDataModule(params)
    valid_dl = imagenet.build_val_loader()

    validate(
        params=params,
        model=model,
        valid_dl=valid_dl,
    )

    logging.info("Done!")


if __name__ == "__main__":
    main()