#!/usr/bin/env python3
# Copyright 2023 Xiaomi Corp. (authors: Xiaoyu Yang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Evaluate an audio-tagging model on the AudioSet eval set and log its mAP.

Usage:
export CUDA_VISIBLE_DEVICES="0"

./zipformer/evaluate.py \
    --epoch 50 \
    --avg 10 \
    --exp-dir zipformer/exp \
    --max-duration 1000
"""

import argparse
import csv
import logging
import math
import os
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import k2
import numpy as np
import sentencepiece as spm
import torch
import torch.nn as nn
import torch.nn.functional as F
from at_datamodule import AudioSetATDatamodule
from lhotse import load_manifest

try:
    from sklearn.metrics import average_precision_score
except ImportError as ex:
    raise RuntimeError(
        f"{ex}\nPlease run\n" "pip3 install -U scikit-learn"
    ) from ex

from train import add_model_arguments, get_model, get_params, str2multihot

from icefall.checkpoint import (
    average_checkpoints,
    average_checkpoints_with_averaged_model,
    find_checkpoints,
    load_checkpoint,
)
from icefall.lexicon import Lexicon
from icefall.utils import (
    AttributeDict,
    make_pad_mask,
    setup_logger,
    store_transcripts,
    str2bool,
    write_error_stats,
)


def get_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for checkpoint selection.

    Model-architecture options are added via ``add_model_arguments``;
    dataloader options are added later in ``main`` via
    ``AudioSetATDatamodule.add_arguments``.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--epoch",
        type=int,
        default=30,
        help="""It specifies the checkpoint to use for decoding.
        Note: Epoch counts from 1.
        You can specify --avg to use more checkpoints for model averaging.""",
    )

    parser.add_argument(
        "--iter",
        type=int,
        default=0,
        help="""If positive, --epoch is ignored and it
        will use the checkpoint exp_dir/checkpoint-iter.pt.
        You can specify --avg to use more checkpoints for model averaging.
        """,
    )

    parser.add_argument(
        "--avg",
        type=int,
        default=15,
        help="Number of checkpoints to average. Automatically select "
        "consecutive checkpoints before the checkpoint specified by "
        "'--epoch' and '--iter'",
    )

    parser.add_argument(
        "--use-averaged-model",
        type=str2bool,
        default=True,
        help="Whether to load averaged model. Currently it only supports "
        "using --epoch. If True, it would decode with the averaged model "
        "over the epoch range from `epoch-avg` (excluded) to `epoch`. "
        "Actually only the models with epoch number of `epoch-avg` and "
        "`epoch` are loaded for averaging. ",
    )

    parser.add_argument(
        "--exp-dir",
        type=str,
        default="zipformer/exp",
        help="The experiment dir",
    )

    add_model_arguments(parser)

    return parser


def inference_one_batch(
    params: AttributeDict,
    model: nn.Module,
    batch: dict,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Run the audio-tagging model on one batch.

    Args:
      params: Unused here; kept for interface compatibility.
      model: Model exposing ``forward_encoder`` and ``forward_audio_tagging``.
      batch: Dataloader batch with ``batch["inputs"]`` (a 3-D feature
        tensor) and ``batch["supervisions"]`` containing ``"audio_event"``
        and ``"num_frames"``.

    Returns:
      ``(audio_logits, label)``, both on CPU. ``audio_logits`` holds the
      sigmoid of the model's tagging output (values in [0, 1]).
      ``label`` is the first value returned by ``str2multihot`` for the
      batch's audio events.
    """
    device = next(model.parameters()).device
    feature = batch["inputs"]
    assert feature.ndim == 3, feature.shape

    feature = feature.to(device)
    # at entry, feature is (N, T, C)

    supervisions = batch["supervisions"]
    audio_event = supervisions["audio_event"]

    label, _ = str2multihot(audio_event)
    label = label.detach().cpu()

    feature_lens = supervisions["num_frames"].to(device)

    encoder_out, encoder_out_lens = model.forward_encoder(feature, feature_lens)

    audio_logits = model.forward_audio_tagging(encoder_out, encoder_out_lens)
    # convert to probabilities between 0-1
    audio_logits = audio_logits.sigmoid().detach().cpu()

    return audio_logits, label


def decode_dataset(
    dl: torch.utils.data.DataLoader,
    params: AttributeDict,
    model: nn.Module,
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
    """Run the model over every batch of ``dl`` and collect its outputs.

    Args:
      dl: Dataloader whose batches carry ``batch["supervisions"]["cut"]``
        (used to count processed cuts) plus the fields read by
        ``inference_one_batch``.
      params: Passed through to ``inference_one_batch``.
      model: The audio-tagging model.

    Returns:
      ``(all_logits, all_labels)``: two lists with one CPU tensor per batch,
      in dataloader order, as returned by ``inference_one_batch``.
    """
    num_cuts = 0

    try:
        num_batches = len(dl)
    except TypeError:
        # Some dataloaders (e.g. sampler-driven ones) may not define __len__.
        num_batches = "?"

    all_logits = []
    all_labels = []

    for batch_idx, batch in enumerate(dl):
        num_cuts += len(batch["supervisions"]["cut"])

        audio_logits, labels = inference_one_batch(
            params=params,
            model=model,
            batch=batch,
        )

        all_logits.append(audio_logits)
        all_labels.append(labels)

        if batch_idx % 20 == 1:
            logging.info(
                f"Processed {num_cuts} cuts already "
                f"(batch {batch_idx + 1}/{num_batches})."
            )

    logging.info("Finish collecting audio logits")

    return all_logits, all_labels


def _find_iter_checkpoints(params: AttributeDict, num: int) -> List[str]:
    """Return the first ``num`` entries of ``find_checkpoints`` for --iter.

    Raises:
      ValueError: if no checkpoints, or fewer than ``num``, are found.
    """
    filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[:num]
    if len(filenames) == 0:
        raise ValueError(
            f"No checkpoints found for"
            f" --iter {params.iter}, --avg {params.avg}"
        )
    if len(filenames) < num:
        raise ValueError(
            f"Not enough checkpoints ({len(filenames)}) found for"
            f" --iter {params.iter}, --avg {params.avg}"
        )
    return filenames


def _load_model_weights(
    params: AttributeDict,
    model: nn.Module,
    device: torch.device,
) -> None:
    """Load the checkpoint weights selected by the CLI options into ``model``.

    Without --use-averaged-model:
      * --iter > 0: plain average of ``avg`` checkpoints from
        ``find_checkpoints``.
      * --avg == 1: load ``epoch-{epoch}.pt`` directly.
      * otherwise: plain average of ``epoch-{epoch-avg+1}.pt`` ..
        ``epoch-{epoch}.pt``, skipping epochs below 1.

    With --use-averaged-model, ``average_checkpoints_with_averaged_model`` is
    called with a start checkpoint (excluded) and an end checkpoint:
      * --iter > 0: the last and first of ``avg + 1`` checkpoints from
        ``find_checkpoints``.
      * otherwise: ``epoch-{epoch-avg}.pt`` and ``epoch-{epoch}.pt``.

    Raises:
      ValueError: if the required checkpoints cannot be selected.
    """
    if not params.use_averaged_model:
        if params.iter > 0:
            filenames = _find_iter_checkpoints(params, params.avg)
        elif params.avg == 1:
            load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
            return
        else:
            start = params.epoch - params.avg + 1
            filenames = [
                f"{params.exp_dir}/epoch-{i}.pt"
                for i in range(start, params.epoch + 1)
                if i >= 1
            ]
        logging.info(f"averaging {filenames}")
        model.to(device)
        model.load_state_dict(
            average_checkpoints(filenames, device=device), strict=False
        )
        return

    if params.iter > 0:
        filenames = _find_iter_checkpoints(params, params.avg + 1)
        filename_start = filenames[-1]
        filename_end = filenames[0]
        logging.info(
            "Calculating the averaged model over iteration checkpoints"
            f" from {filename_start} (excluded) to {filename_end}"
        )
    else:
        # Raise instead of assert so the check survives `python -O`.
        if params.avg <= 0:
            raise ValueError(f"--avg must be positive, got {params.avg}")
        start = params.epoch - params.avg
        if start < 1:
            raise ValueError(
                f"--epoch ({params.epoch}) - --avg ({params.avg}) must be >= 1"
                " when --use-averaged-model is True"
            )
        filename_start = f"{params.exp_dir}/epoch-{start}.pt"
        filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
        logging.info(
            f"Calculating the averaged model over epoch range from "
            f"{start} (excluded) to {params.epoch}"
        )

    model.to(device)
    model.load_state_dict(
        average_checkpoints_with_averaged_model(
            filename_start=filename_start,
            filename_end=filename_end,
            device=device,
        ),
        strict=False,
    )


@torch.no_grad()
def main():
    """Evaluate a checkpoint on the AudioSet eval set and log its mAP."""
    parser = get_parser()
    AudioSetATDatamodule.add_arguments(parser)
    args = parser.parse_args()
    args.exp_dir = Path(args.exp_dir)

    params = get_params()
    params.update(vars(args))

    params.res_dir = params.exp_dir / "inference_audio_tagging"

    if params.iter > 0:
        params.suffix = f"iter-{params.iter}-avg-{params.avg}"
    else:
        params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"

    if params.use_averaged_model:
        params.suffix += "-use-averaged-model"

    setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
    logging.info("Evaluation started")

    logging.info(params)

    device = torch.device("cpu")
    if torch.cuda.is_available():
        device = torch.device("cuda", 0)

    logging.info("About to create model")

    model = get_model(params)
    _load_model_weights(params, model, device)

    model.to(device)
    model.eval()

    num_param = sum(p.numel() for p in model.parameters())
    logging.info(f"Number of model parameters: {num_param}")

    # Presumably makes the datamodule include cuts in each batch;
    # decode_dataset reads batch["supervisions"]["cut"].
    args.return_cuts = True
    audioset = AudioSetATDatamodule(args)

    audioset_cuts = audioset.audioset_eval_cuts()

    audioset_dl = audioset.valid_dataloaders(audioset_cuts)

    logits, labels = decode_dataset(
        dl=audioset_dl,
        params=params,
        model=model,
    )

    # squeeze(dim=1) only removes dim 1 if it has size 1; otherwise no-op.
    logits = torch.cat(logits, dim=0).squeeze(dim=1).detach().numpy()
    labels = torch.cat(labels, dim=0).long().detach().numpy()

    # compute the metric (sklearn's default averaging is "macro")
    mAP = average_precision_score(
        y_true=labels,
        y_score=logits,
    )

    logging.info(f"mAP for audioset eval is: {mAP}")

    logging.info("Done")


if __name__ == "__main__":
    main()