mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-11 11:02:29 +00:00
132 lines
3.9 KiB
Python
Executable File
#!/usr/bin/env python3
|
|
# Copyright 2023 Xiaomi Corp. (Author: Weiji Zhuang,
|
|
# Liyong Guo)
|
|
#
|
|
# See ../../../../LICENSE for clarification regarding multiple authors
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
import argparse
|
|
import logging
|
|
from typing import Dict, Tuple
|
|
|
|
import matplotlib.pyplot as plt
|
|
import numpy as np
|
|
from pathlib import Path
|
|
from sklearn.metrics import roc_curve, auc
|
|
|
|
from icefall.utils import setup_logger
|
|
|
|
|
|
def get_args():
    """Parse the command-line options of this script.

    Three string options are defined, all of them required:
    ``--positive-score-file``, ``--negative-score-file`` and ``--legend``.

    Returns:
      The namespace produced by ``parser.parse_args()``, with attributes
      ``positive_score_file``, ``negative_score_file`` and ``legend``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--positive-score-file",
        type=str,
        required=True,
        help="score file of positive data",
    )
    parser.add_argument(
        "--negative-score-file",
        type=str,
        required=True,
        help="score file of negative data",
    )
    parser.add_argument(
        "--legend",
        type=str,
        required=True,
        help="legend of ROC curve picture.",
    )
    return parser.parse_args()
|
|
|
|
|
|
def load_score(score_file: Path) -> Dict[str, float]:
    """Read a two-column score file into a dict.

    Args:
      score_file: Path to score file. Each non-blank line has two
        whitespace-separated columns. The first column is utt-id, and the
        second one is score.
        This score could be viewed as probability of being wakeup word.

    Returns:
      A dict whose key is utt-id and whose value is the corresponding score.
      If an utt-id appears more than once, the last occurrence wins.

    Raises:
      ValueError: If a non-blank line does not have exactly two columns, or
        if the score column cannot be converted to float.
    """
    scores = {}
    with open(score_file, "r", encoding="utf8") as fin:
        for line_no, line in enumerate(fin, start=1):
            arr = line.split()
            if not arr:
                # Tolerate blank lines (e.g. a trailing empty line) instead
                # of aborting the whole evaluation.
                continue
            if len(arr) != 2:
                # An explicit raise is used because `assert` is removed
                # when Python runs with -O.
                raise ValueError(
                    f"{score_file}:{line_no}: expected 2 columns "
                    f"(utt-id score), got {len(arr)}: {line.rstrip()!r}"
                )
            key, score = arr
            scores[key] = float(score)
    return scores
|
|
|
|
|
|
def get_roc_and_auc(
    pos_dict: Dict,
    neg_dict: Dict,
) -> Tuple[np.ndarray, np.ndarray, float]:
    """Compute the ROC curve and its AUC from positive and negative scores.

    Only the values of the two dicts are used; the keys are ignored.

    Args:
      pos_dict: scores of positive samples (labelled 1).
      neg_dict: scores of negative samples (labelled 0).

    Return:
      A tuple of three elements, which will be used to plot ROC curve.
      The first and second elements are the first two values returned by
      sklearn.metrics.roc_curve; refer to its documentation for their
      meaning. The third element is area under the ROC curve(AUC).
    """
    pos_scores = np.fromiter(pos_dict.values(), dtype=float)
    neg_scores = np.fromiter(neg_dict.values(), dtype=float)

    # Positives come first in both arrays so labels line up with scores.
    scores = np.concatenate([pos_scores, neg_scores])
    y = np.concatenate(
        [
            np.ones_like(pos_scores, dtype=int),
            np.zeros_like(neg_scores, dtype=int),
        ]
    )

    # The third value returned by roc_curve (thresholds) is not needed.
    fpr, tpr, _ = roc_curve(y, scores, pos_label=1)
    roc_auc = auc(fpr, tpr)

    return fpr, tpr, roc_auc
|
|
|
|
|
|
def main():
    """Compute the AUC for the given score files and save the ROC plot.

    Reads the command-line options, loads the positive and negative score
    files, logs the AUC, and writes ``<legend>.pdf`` into the directory that
    contains the positive score file. The log file path passed to
    ``setup_logger`` is under ``log/`` in that same directory.
    """
    args = get_args()

    # Both the log and the PDF are placed next to the positive score file.
    output_path = Path(args.positive_score_file).parent
    setup_logger(f"{output_path}/log/log-auc-{args.legend}")
    logging.info(f"About to compute AUC of {args.legend}")

    pos_dict = load_score(args.positive_score_file)
    neg_dict = load_score(args.negative_score_file)
    fpr, tpr, roc_auc = get_roc_and_auc(pos_dict, neg_dict)

    plt.figure(figsize=(16, 9))
    # Format the AUC inside the f-string. Applying "%" to the finished
    # f-string would also interpret any "%" in the user-supplied legend
    # as a format directive and fail.
    plt.plot(fpr, tpr, label=f"{args.legend}(AUC = {roc_auc:1.8f})")

    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.0])
    plt.xlabel("False Positive Rate")
    plt.ylabel("True Positive Rate")
    plt.title("Receiver operating characteristic(ROC)")
    plt.legend(loc="lower right")

    logging.info(f"AUC of {args.legend} {output_path}: {roc_auc}")
    plt.savefig(f"{output_path}/{args.legend}.pdf", bbox_inches="tight")
|
|
|
|
|
|
# Run the tool only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|