This commit is contained in:
glynpu 2023-03-16 12:37:46 +08:00
parent a49817385a
commit 39c0ae7749
2 changed files with 394 additions and 0 deletions

279
egs/himia/wuw/ctc_tdnn/decode.py Executable file
View File

@ -0,0 +1,279 @@
#!/usr/bin/env python3
# Copyright 2023 Xiaomi Corp. (Author: Weiji Zhuang,
# Liyong Guo)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import copy
import logging
from concurrent.futures import ProcessPoolExecutor
from typing import Tuple
import numpy as np
from lhotse.features.io import NumpyHdf5Reader
from tqdm import tqdm
from icefall.utils import AttributeDict
from train import get_params
from graph import ctc_trivial_decoding_graph
class Arc:
def __init__(
self, src_state: int, dst_state: int, ilabel: int, olabel: int
) -> None:
self.src_state = int(src_state)
self.dst_state = int(dst_state)
self.ilabel = int(ilabel)
self.olabel = int(olabel)
def next_state(self) -> None:
return self.dst_state
class State:
def __init__(self) -> None:
self.arc_list = list()
def add_arc(self, arc: Arc) -> None:
self.arc_list.append(arc)
class FiniteStateTransducer:
"""Represents a decoding graph for wake word detection."""
def __init__(self, graph: str) -> None:
self.state_list = list()
for arc_str in graph.split("\n"):
arc = arc_str.strip().split()
if len(arc) == 0:
continue
# 1 and 2 for final state
# 4 for non-final state
assert len(arc) in [1, 2, 4], f"{len(arc)} {arc_str}"
if len(arc) == 4: # Non-final state
# FST must be sorted
if len(self.state_list) <= int(arc[0]):
new_state = State()
self.state_list.append(new_state)
self.state_list[int(arc[0])].add_arc(
Arc(arc[0], arc[1], arc[2], arc[3])
)
else:
self.final_state_id = int(arc[0])
def to_str(self) -> None:
fst_str = ""
for state_idx in range(len(self.state_list)):
cur_state = self.state_list[state_idx]
for arc_idx in range(len(cur_state.arc_list)):
cur_arc = cur_state.arc_list[arc_idx]
ilabel = cur_arc.ilabel
olabel = cur_arc.olabel
src_state = cur_arc.src_state
dst_state = cur_arc.dst_state
fst_str += f"{src_state} {dst_state} {ilabel} {olabel}\n"
fst_str += f"{dst_state}\n"
return fst_str
class Token:
def __init__(self) -> None:
self.is_active = False
self.total_score = -float("inf")
self.keyword_frames = 0
self.average_keyword_score = -float("inf")
self.average_max_keyword_score = 0.0
def set_token(
self,
src_token,
is_keyword_ilabel: bool,
acoustic_score: float,
) -> None:
"""
A dynamic programming process computing the highest score for a token
from all possible paths which could reach this token.
Args:
src_token: The source token connected to current token with an arc.
is_keyword_ilabel: If true, the arc consumes an input label which is
a part of wake word. Otherwhise, the input label is
blank or unknown, i.e. current token is still not part of wake word.
acoustic_score: acoustic score of this arc.
"""
if (
not self.is_active
or self.total_score < src_token.total_score + acoustic_score
):
self.is_active = True
self.total_score = src_token.total_score + acoustic_score
if is_keyword_ilabel:
self.average_keyword_score = (
acoustic_score
+ src_token.average_keyword_score * src_token.keyword_frames
) / (src_token.keyword_frames + 1)
self.keyword_frames = src_token.keyword_frames + 1
else:
self.average_keyword_score = 0.0
class SingleDecodable:
def __init__(
self,
model_output,
keyword_ilabel_start,
graph,
):
"""
Args:
model_output: log_softmax(logit) with shape [T, C]
keyword_ilabel_start: index of the first token of the wake word.
In this recipe, tokens not for wake word has smaller token index,
i.e. blank 0; unk 1.
graph: decoding graph of the wake word.
"""
self.init_token_list = [Token() for i in range(len(graph.state_list))]
self.reset_token_list()
self.model_output = model_output
self.T = model_output.shape[0]
self.utt_score = 0.0
self.current_frame_index = 0
self.keyword_ilabel_start = keyword_ilabel_start
self.graph = graph
self.number_tokens = len(self.cur_token_list)
def reset_token_list(self) -> None:
"""
Reset all tokens to a condition without consuming any acoustic frames.
"""
self.cur_token_list = copy.deepcopy(self.init_token_list)
self.expand_token_list = copy.deepcopy(self.init_token_list)
self.cur_token_list[0].is_active = True
self.cur_token_list[0].total_score = 0
self.cur_token_list[0].average_keyword_score = 0
def process_oneframe(self) -> None:
"""
Decode a frame and update all tokens.
"""
for state_id, cur_token in enumerate(self.cur_token_list):
if cur_token.is_active:
for arc_id in self.graph.state_list[state_id].arc_list:
acoustic_score = self.model_output[self.current_frame_index][
arc_id.ilabel
]
is_keyword_ilabel = arc_id.ilabel >= self.keyword_ilabel_start
self.expand_token_list[arc_id.next_state()].set_token(
cur_token,
is_keyword_ilabel,
acoustic_score,
)
# use best_score to keep total_score in a good range
self.best_state_id = 0
best_score = self.expand_token_list[0].total_score
for state_id in range(self.number_tokens):
if self.expand_token_list[state_id].is_active:
if best_score < self.expand_token_list[state_id].total_score:
best_score = self.expand_token_list[state_id].total_score
self.best_state_id = state_id
self.cur_token_list = self.expand_token_list
for state_id in range(self.number_tokens):
self.cur_token_list[state_id].total_score -= best_score
self.expand_token_list = copy.deepcopy(self.init_token_list)
potential_score = np.exp(
self.cur_token_list[self.graph.final_state_id].average_keyword_score
)
if potential_score > self.utt_score:
self.utt_score = potential_score
self.current_frame_index += 1
def decode_utt(
params: AttributeDict, utt_id: str, post_file, graph: FiniteStateTransducer
) -> Tuple[str, float]:
"""
Decode a single utterance.
Args:
params:
The return value of :func:`get_params`.
utt_id: utt_id to be decoded, used to fetch posterior matrix from post_file.
post_file: file to save posterior for all test set.
graph: decoding graph.
Returns:
utt_id and its corresponding probability to be a wake word.
"""
reader = NumpyHdf5Reader(post_file)
model_output = reader.read(utt_id)
keyword_ilabel_start = params.wakeup_word_tokens[0]
decodable = SingleDecodable(
model_output=model_output,
keyword_ilabel_start=keyword_ilabel_start,
graph=graph,
)
for t in range(decodable.T):
decodable.process_oneframe()
return utt_id, decodable.utt_score
def get_parser():
parser = argparse.ArgumentParser(
description="A simple FST decoder for the wake word detection\n"
)
parser.add_argument(
"--decoding-graph", help="decoding graph", default="himia_ctc_graph.txt"
)
parser.add_argument("--post-h5", help="model output in h5 format")
parser.add_argument("--score-file", help="file to save scores of each utterance")
return parser
def main():
logging.basicConfig(
level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s"
)
parser = get_parser()
args = parser.parse_args()
params = get_params()
params.update(vars(args))
keys = NumpyHdf5Reader(params.post_h5).hdf.keys()
graph = FiniteStateTransducer(ctc_trivial_decoding_graph(params.wakeup_word_tokens))
logging.info(f"Graph used:\n{graph.to_str()}")
logging.info("About to load data to decoder.")
with ProcessPoolExecutor() as executor, open(
params.score_file, "w", encoding="utf8"
) as fout:
futures = [
executor.submit(decode_utt, params, key, params.post_h5, graph)
for key in tqdm(keys)
]
logging.info("Decoding.")
for future in tqdm(futures):
k, v = future.result()
fout.write(str(k) + " " + str(v) + "\n")
if __name__ == "__main__":
main()

115
egs/himia/wuw/local/auc.py Executable file
View File

@ -0,0 +1,115 @@
#!/usr/bin/env python3
# Copyright 2023 Xiaomi Corp. (Author: Weiji Zhuang,
# Liyong Guo)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
from typing import Dict, Tuple
import matplotlib.pyplot as plt
import numpy as np
from pathlib import Path
from sklearn.metrics import roc_curve, auc
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--positive-score-file", required=True, help="score file of positive data"
)
parser.add_argument(
"--negative-score-file", required=True, help="score file of negative data"
)
parser.add_argument("--legend", required=True, help="utt2dur file of negative data")
return parser.parse_args()
def load_score(score_file: Path) -> Dict[str, float]:
"""
Args:
score_file: Path to score file. Each line has two columns.
The first colume is utt-id, and the second one is score.
This score could be viewed as probability of being wakeup word.
Returns:
A dict with that key is utt-id and value is corresponding score.
"""
pos_dict = {}
with open(score_file, "r", encoding="utf8") as fin:
for line in fin:
arr = line.strip().split()
assert len(arr) == 2
key = arr[0]
score = float(arr[1])
pos_dict[key] = score
return pos_dict
def get_roc_and_auc(
pos_dict: Dict,
neg_dict: Dict,
) -> Tuple[np.array, np.array, float]:
"""
Args:
pos_dict: scores of positive samples.
neg_dict: scores of negative samples.
Return:
A tuple of three elements, which will be used to plot roc curve.
Refer to sklearn.metrics.roc_curve for meaning of the first and second elements.
The third element is area under the roc curve(AUC).
"""
pos_scores = np.fromiter(pos_dict.values(), dtype=float)
neg_scores = np.fromiter(neg_dict.values(), dtype=float)
pos_y = np.ones_like(pos_scores, dtype=int)
neg_y = np.zeros_like(neg_scores, dtype=int)
scores = np.concatenate([pos_scores, neg_scores])
y = np.concatenate([pos_y, neg_y])
fpr, tpr, thresholds = roc_curve(y, scores, pos_label=1)
roc_auc = auc(fpr, tpr)
return fpr, tpr, roc_auc
def main():
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
args = get_args()
logging.basicConfig(format=formatter, level=logging.INFO)
pos_dict = load_score(args.positive_score_file)
neg_dict = load_score(args.negative_score_file)
fpr, tpr, roc_auc = get_roc_and_auc(pos_dict, neg_dict)
plt.figure(figsize=(16, 9))
plt.plot(fpr, tpr, label=f"{args.legend}(AUC = %1.8f)" % roc_auc)
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.0])
plt.xlabel("False Positive Rate")
plt.ylabel("True Positive Rate")
plt.title("Receiver operating characteristic(ROC)")
plt.legend(loc="lower right")
output_path = Path(args.positive_score_file).parent
logging.info(f"AUC of {args.legend} {output_path}: {roc_auc}")
plt.savefig(f"{output_path}/{args.legend}.pdf", bbox_inches="tight")
if __name__ == "__main__":
main()