diff --git a/egs/audioset/AT/zipformer/evaluate.py b/egs/audioset/AT/zipformer/evaluate.py new file mode 100644 index 000000000..3fb1c9c02 --- /dev/null +++ b/egs/audioset/AT/zipformer/evaluate.py @@ -0,0 +1,344 @@ +#!/usr/bin/env python3 +# Copyright 2023 Xiaomi Corp. (authors: Xiaoyu Yang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: + +export CUDA_VISIBLE_DEVICES="0,1,2,3" + + +./zipformer/evaluate.py \ + --num-epochs 50 \ + --start-epoch 10 \ + --use-fp16 1 \ + --exp-dir zipformer/exp \ + --max-duration 1000 + + +""" + +import argparse +import csv +import logging +import math +import os +from collections import defaultdict +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import k2 +import numpy as np +import sentencepiece as spm +import torch +import torch.nn as nn +import torch.nn.functional as F +from at_datamodule import AudioSetATDatamodule +from lhotse import load_manifest +from sklearn.metrics import average_precision_score +from train import add_model_arguments, get_model, get_params, str2multihot + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.lexicon import Lexicon +from icefall.utils import ( + AttributeDict, + make_pad_mask, + setup_logger, + store_transcripts, + str2bool, + write_error_stats, +) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zipformer/exp", + help="The experiment dir", + ) + + add_model_arguments(parser) + + return parser + + +def inference_one_batch( + params: AttributeDict, + model: nn.Module, + batch: dict, +): + device = next(model.parameters()).device + feature = batch["inputs"] + assert feature.ndim == 3 + + feature = feature.to(device) + # at entry, feature is (N, T, C) + + supervisions = batch["supervisions"] + audio_event = supervisions["audio_event"] + + label, orig_labels = str2multihot(audio_event) + label = label.detach().cpu() + + feature_lens = supervisions["num_frames"].to(device) + + encoder_out, encoder_out_lens = model.forward_encoder(feature, feature_lens) + + audio_logits = model.forward_audio_tagging(encoder_out, encoder_out_lens) + # convert to probabilities between 0-1 + audio_logits = audio_logits.sigmoid().detach().cpu() + + return audio_logits, label + + +def decode_dataset( + dl: torch.utils.data.DataLoader, + params: AttributeDict, + model: nn.Module, +) -> Dict: + num_cuts = 0 + embedding_dict = {} + teacher_embedding_dict = {} + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" + + all_logits = [] + all_labels = [] + + for batch_idx, batch in enumerate(dl): + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] + num_cuts += len(cut_ids) + + audio_logits, labels = inference_one_batch( + params=params, + model=model, + batch=batch, + ) + + all_logits.append(audio_logits) + all_labels.append(labels) + + if batch_idx % 20 == 1: + logging.info(f"Processed {num_cuts} cuts already.") + logging.info("Finish collecting audio logits") + + return all_logits, all_labels + + +@torch.no_grad() +def main(): + parser = get_parser() + AudioSetATDatamodule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + params.res_dir = params.exp_dir / "inference_audio_tagging" + + if params.iter > 0: + params.suffix = f"iter-{params.iter}-avg-{params.avg}" + else: + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + + if params.use_averaged_model: + params.suffix += "-use-averaged-model" + + setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") + logging.info("Evaluation started") + + logging.info(params) + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info("About to create model") + + model = get_model(params) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict( + average_checkpoints(filenames, device=device), strict=False + ) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict( + average_checkpoints(filenames, device=device), strict=False + ) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ), + strict=False, + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ), + strict=False, + ) + + model.to(device) + model.eval() + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + args.return_cuts = True + audioset = AudioSetATDatamodule(args) + + audioset_cuts = audioset.audioset_eval_cuts() + + audioset_dl = audioset.valid_dataloaders(audioset_cuts) + + test_sets = ["audioset_eval"] + + logits, labels = decode_dataset( + dl=audioset_dl, + params=params, + model=model, + ) + + logits = torch.cat(logits, dim=0).squeeze(dim=1).detach().numpy() + labels = torch.cat(labels, dim=0).long().detach().numpy() + + # compute the metric + mAP = average_precision_score( + y_true=labels, + y_score=logits, + ) + + logging.info(f"mAP for audioset eval is: {mAP}") + + logging.info("Done") + + +if __name__ == "__main__": + main() diff --git a/egs/audioset/AT/zipformer/train.py b/egs/audioset/AT/zipformer/train.py index e2c23b716..a6e6490ad 100644 --- a/egs/audioset/AT/zipformer/train.py +++ b/egs/audioset/AT/zipformer/train.py @@ -26,7 +26,7 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3" --start-epoch 1 \ --use-fp16 1 \ --exp-dir zipformer/exp \ - --full-libri 1 \ + --audioset-subset full \ --max-duration 1000