add evaluation script

This commit is contained in:
marcoyang 2023-12-19 17:20:49 +08:00
parent 57ff00de6a
commit bd01c21200
2 changed files with 345 additions and 1 deletions

View File

@ -0,0 +1,344 @@
#!/usr/bin/env python3
# Copyright 2023 Xiaomi Corp. (authors: Xiaoyu Yang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Usage:
export CUDA_VISIBLE_DEVICES="0,1,2,3"
./zipformer/evaluate.py \
--num-epochs 50 \
--start-epoch 10 \
--use-fp16 1 \
--exp-dir zipformer/exp \
--max-duration 1000
"""
import argparse
import csv
import logging
import math
import os
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import k2
import numpy as np
import sentencepiece as spm
import torch
import torch.nn as nn
import torch.nn.functional as F
from at_datamodule import AudioSetATDatamodule
from lhotse import load_manifest
from sklearn.metrics import average_precision_score
from train import add_model_arguments, get_model, get_params, str2multihot
from icefall.checkpoint import (
average_checkpoints,
average_checkpoints_with_averaged_model,
find_checkpoints,
load_checkpoint,
)
from icefall.lexicon import Lexicon
from icefall.utils import (
AttributeDict,
make_pad_mask,
setup_logger,
store_transcripts,
str2bool,
write_error_stats,
)
def get_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the evaluation script.

    Registers the checkpoint-selection options (``--epoch``, ``--iter``,
    ``--avg``, ``--use-averaged-model``), the experiment directory, and then
    lets ``add_model_arguments`` append its own options.

    Returns:
      The configured ``argparse.ArgumentParser``.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--epoch",
        type=int,
        default=30,
        help=(
            "It specifies the checkpoint to use for decoding.\n"
            "Note: Epoch counts from 1.\n"
            "You can specify --avg to use more checkpoints for model averaging."
        ),
    )

    parser.add_argument(
        "--iter",
        type=int,
        default=0,
        help=(
            "If positive, --epoch is ignored and it\n"
            "will use the checkpoint exp_dir/checkpoint-iter.pt.\n"
            "You can specify --avg to use more checkpoints for model averaging.\n"
        ),
    )

    parser.add_argument(
        "--avg",
        type=int,
        default=15,
        help="Number of checkpoints to average. Automatically select "
        "consecutive checkpoints before the checkpoint specified by "
        "'--epoch' and '--iter'",
    )

    parser.add_argument(
        "--use-averaged-model",
        type=str2bool,
        default=True,
        # The second sentence needs a leading space, otherwise the rendered
        # help reads "`epoch`.Actually".
        help="Whether to load averaged model. Currently it only supports "
        "using --epoch. If True, it would decode with the averaged model "
        "over the epoch range from `epoch-avg` (excluded) to `epoch`. "
        "Actually only the models with epoch number of `epoch-avg` and "
        "`epoch` are loaded for averaging. ",
    )

    parser.add_argument(
        "--exp-dir",
        type=str,
        default="zipformer/exp",
        help="The experiment dir",
    )

    # Model-specific options are declared next to the model code in train.py.
    add_model_arguments(parser)

    return parser
def inference_one_batch(
    params: AttributeDict,
    model: nn.Module,
    batch: dict,
):
    """Run the audio-tagging forward pass on a single batch.

    Args:
      params:
        Not read by this function; kept so the call signature stays the same.
      model:
        Module exposing ``forward_encoder`` and ``forward_audio_tagging``.
        The batch is moved to the device of its first parameter.
      batch:
        Dict with an ``"inputs"`` tensor of 3 dimensions and a
        ``"supervisions"`` dict holding ``"audio_event"`` and ``"num_frames"``.

    Returns:
      A pair ``(probabilities, targets)``, both moved to CPU. The first is
      the sigmoid of the tagging logits; the second is the first value
      returned by ``str2multihot`` (presumably multi-hot targets -- confirm
      against train.py).
    """
    target_device = next(model.parameters()).device

    # at entry, the inputs are (N, T, C)
    inputs = batch["inputs"]
    assert inputs.ndim == 3
    inputs = inputs.to(target_device)

    sup = batch["supervisions"]
    events = sup["audio_event"]
    # Only the first value returned by str2multihot is needed here.
    targets, _ = str2multihot(events)
    targets = targets.detach().cpu()

    num_frames = sup["num_frames"].to(target_device)
    enc_out, enc_out_lens = model.forward_encoder(inputs, num_frames)
    tagging_logits = model.forward_audio_tagging(enc_out, enc_out_lens)

    # convert to probabilities between 0-1
    probabilities = tagging_logits.sigmoid().detach().cpu()
    return probabilities, targets
def decode_dataset(
    dl: torch.utils.data.DataLoader,
    params: AttributeDict,
    model: nn.Module,
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
    """Run inference over every batch of a dataloader.

    Args:
      dl:
        Iterable of batches. Each batch must carry the cuts under
        ``batch["supervisions"]["cut"]`` (used only for progress counting).
      params:
        Forwarded unchanged to ``inference_one_batch``.
      model:
        The model to evaluate.

    Returns:
      A pair ``(all_logits, all_labels)``: two lists with one entry per
      batch, holding the two values returned by ``inference_one_batch``,
      in dataloader order.
    """
    num_cuts = 0
    all_logits = []
    all_labels = []

    for batch_idx, batch in enumerate(dl):
        cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
        num_cuts += len(cut_ids)

        audio_logits, labels = inference_one_batch(
            params=params,
            model=model,
            batch=batch,
        )

        all_logits.append(audio_logits)
        all_labels.append(labels)

        # Progress report every 20 batches (first one after the 2nd batch).
        if batch_idx % 20 == 1:
            logging.info(f"Processed {num_cuts} cuts already.")

    logging.info("Finish collecting audio logits")

    return all_logits, all_labels
def _load_evaluation_weights(
    params: AttributeDict,
    model: nn.Module,
    device: torch.device,
) -> None:
    """Load the checkpoint weights selected by ``params`` into ``model``.

    Four selection modes exist, driven by ``params.use_averaged_model`` and
    ``params.iter``; ``params.avg`` and ``params.epoch`` pick the range.
    The model is modified in place. In every mode except the single-epoch
    one (``avg == 1`` without ``--iter``) it is also moved to ``device``.

    Raises:
      ValueError: If ``--iter`` is used and fewer checkpoints than required
        are found in ``params.exp_dir``.
    """
    if params.use_averaged_model:
        if params.iter > 0:
            # One extra checkpoint is needed: the oldest one only marks the
            # (excluded) start of the averaging range.
            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
                : params.avg + 1
            ]
            if len(filenames) == 0:
                raise ValueError(
                    f"No checkpoints found for"
                    f" --iter {params.iter}, --avg {params.avg}"
                )
            elif len(filenames) < params.avg + 1:
                raise ValueError(
                    f"Not enough checkpoints ({len(filenames)}) found for"
                    f" --iter {params.iter}, --avg {params.avg}"
                )
            # NOTE(review): relies on find_checkpoints returning the newest
            # checkpoint first -- confirm against icefall.checkpoint.
            filename_start = filenames[-1]
            filename_end = filenames[0]
            logging.info(
                "Calculating the averaged model over iteration checkpoints"
                f" from {filename_start} (excluded) to {filename_end}"
            )
        else:
            assert params.avg > 0, params.avg
            start = params.epoch - params.avg
            assert start >= 1, start
            filename_start = f"{params.exp_dir}/epoch-{start}.pt"
            filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
            logging.info(
                f"Calculating the averaged model over epoch range from "
                f"{start} (excluded) to {params.epoch}"
            )

        model.to(device)
        model.load_state_dict(
            average_checkpoints_with_averaged_model(
                filename_start=filename_start,
                filename_end=filename_end,
                device=device,
            ),
            strict=False,
        )
        return

    if params.iter > 0:
        filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
            : params.avg
        ]
        if len(filenames) == 0:
            raise ValueError(
                f"No checkpoints found for"
                f" --iter {params.iter}, --avg {params.avg}"
            )
        elif len(filenames) < params.avg:
            raise ValueError(
                f"Not enough checkpoints ({len(filenames)}) found for"
                f" --iter {params.iter}, --avg {params.avg}"
            )
    elif params.avg == 1:
        # A single epoch checkpoint: nothing to average.
        load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
        return
    else:
        start = params.epoch - params.avg + 1
        # Epoch counts from 1, so indices below 1 are skipped.
        filenames = [
            f"{params.exp_dir}/epoch-{i}.pt"
            for i in range(start, params.epoch + 1)
            if i >= 1
        ]

    logging.info(f"averaging {filenames}")
    model.to(device)
    model.load_state_dict(
        average_checkpoints(filenames, device=device), strict=False
    )


@torch.no_grad()
def main():
    """Evaluate an audio-tagging model on the AudioSet eval cuts.

    Parses the command line, builds the model, loads (averaged) checkpoint
    weights, runs inference over the eval dataloader and logs the mean
    average precision computed by ``sklearn.metrics.average_precision_score``.
    The log file is written under ``<exp_dir>/inference_audio_tagging``.
    """
    parser = get_parser()
    AudioSetATDatamodule.add_arguments(parser)
    args = parser.parse_args()
    args.exp_dir = Path(args.exp_dir)

    params = get_params()
    params.update(vars(args))

    params.res_dir = params.exp_dir / "inference_audio_tagging"

    if params.iter > 0:
        params.suffix = f"iter-{params.iter}-avg-{params.avg}"
    else:
        params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"

    if params.use_averaged_model:
        params.suffix += "-use-averaged-model"

    setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
    logging.info("Evaluation started")

    logging.info(params)

    # Use the first visible GPU when there is one, otherwise the CPU.
    device = torch.device("cpu")
    if torch.cuda.is_available():
        device = torch.device("cuda", 0)

    logging.info("About to create model")

    model = get_model(params)
    _load_evaluation_weights(params, model, device)

    model.to(device)
    model.eval()

    num_param = sum(p.numel() for p in model.parameters())
    logging.info(f"Number of model parameters: {num_param}")

    # decode_dataset reads batch["supervisions"]["cut"]; presumably this flag
    # is what makes the datamodule include the cuts -- confirm in
    # at_datamodule.py.
    args.return_cuts = True
    audioset = AudioSetATDatamodule(args)

    audioset_cuts = audioset.audioset_eval_cuts()
    audioset_dl = audioset.valid_dataloaders(audioset_cuts)

    logits, labels = decode_dataset(
        dl=audioset_dl,
        params=params,
        model=model,
    )

    # Stack the per-batch results into single numpy arrays for sklearn.
    logits = torch.cat(logits, dim=0).squeeze(dim=1).detach().numpy()
    labels = torch.cat(labels, dim=0).long().detach().numpy()

    # compute the metric
    mAP = average_precision_score(
        y_true=labels,
        y_score=logits,
    )

    logging.info(f"mAP for audioset eval is: {mAP}")

    logging.info("Done")


if __name__ == "__main__":
    main()

View File

@ -26,7 +26,7 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
--start-epoch 1 \
--use-fp16 1 \
--exp-dir zipformer/exp \
--full-libri 1 \
--audioset-subset full \
--max-duration 1000