2023-03-16 20:03:57 +08:00

314 lines
10 KiB
Python
Executable File

#!/usr/bin/env python3
# Copyright 2023 Xiaomi Corp. (Author: Weiji Zhuang,
# Liyong Guo)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import copy
import logging
from concurrent.futures import ProcessPoolExecutor
from typing import Tuple
from pathlib import Path
import numpy as np
from lhotse.features.io import NumpyHdf5Reader
from tqdm import tqdm
from icefall.utils import (
AttributeDict,
setup_logger,
)
from train import get_params
from graph import ctc_trivial_decoding_graph
class Arc:
    """A single FST arc: a labeled transition between two states."""

    def __init__(
        self, src_state: int, dst_state: int, ilabel: int, olabel: int
    ) -> None:
        """
        Args:
          src_state: Index of the state this arc leaves.
          dst_state: Index of the state this arc enters.
          ilabel: Input label consumed when traversing this arc.
          olabel: Output label emitted when traversing this arc.
        """
        self.src_state = int(src_state)
        self.dst_state = int(dst_state)
        self.ilabel = int(ilabel)
        self.olabel = int(olabel)

    def next_state(self) -> int:
        # Fix: the original annotated the return type as None although the
        # method returns the destination state index (an int).
        return self.dst_state
class State:
    """A state of the decoding graph, holding its outgoing arcs."""

    def __init__(self) -> None:
        # Outgoing arcs in insertion order.
        self.arc_list = []

    def add_arc(self, arc: Arc) -> None:
        """Append one outgoing arc to this state."""
        self.arc_list.append(arc)
class FiniteStateTransducer:
    """Represents a decoding graph for wake word detection."""

    def __init__(self, graph: str) -> None:
        """
        Construct a decoding graph in FST format given string format graph.

        Args:
          graph: A string format fst. Each arc is separated by "\n".
        """
        self.state_list = list()
        for arc_str in graph.split("\n"):
            arc = arc_str.strip().split()
            if len(arc) == 0:
                continue
            # An arc may contain 1, 2 or 4 elements, with format:
            # src_state [dst_state] [ilabel] [olabel]
            # 1 and 2 for final state
            # 4 for non-final state
            assert len(arc) in [1, 2, 4], f"{len(arc)} {arc_str}"
            arc = [int(element) for element in arc]
            src_state_id = arc[0]
            # Highest state id seen so far (-1 when no state exists yet).
            max_state_id = len(self.state_list) - 1
            if len(arc) == 4:  # Non-final state
                # Fix: the original message referenced an undefined name
                # `cur_number_states`, raising NameError instead of the
                # intended AssertionError for an unsorted graph.
                assert max_state_id <= src_state_id, (
                    f"Fsa must be sorted by src_state, "
                    f"while {max_state_id} <= {src_state_id} fails. Check your graph."
                )
                if max_state_id < src_state_id:
                    # First time this src_state is seen: allocate its State.
                    new_state = State()
                    self.state_list.append(new_state)
                self.state_list[src_state_id].add_arc(
                    Arc(src_state_id, arc[1], arc[2], arc[3])
                )
            else:
                # Final-state line: must name the last allocated state,
                # otherwise it has no incoming/outgoing arcs in state_list.
                assert (
                    max_state_id == src_state_id
                ), "Final state seems unreachable. Check your graph."
                self.final_state_id = src_state_id

    def to_str(self) -> str:
        # Fix: the original annotated the return type as None although a
        # string (the serialized FST, inverse of __init__) is returned.
        fst_str = ""
        number_states = len(self.state_list)
        if number_states == 0:
            return fst_str
        for state_idx in range(number_states):
            cur_state = self.state_list[state_idx]
            for arc_idx in range(len(cur_state.arc_list)):
                cur_arc = cur_state.arc_list[arc_idx]
                ilabel = cur_arc.ilabel
                olabel = cur_arc.olabel
                src_state = cur_arc.src_state
                dst_state = cur_arc.dst_state
                fst_str += f"{src_state} {dst_state} {ilabel} {olabel}\n"
        # The final state is the destination of the last arc emitted above.
        fst_str += f"{dst_state}\n"
        return fst_str
class Token:
    """A decoding token attached to one graph state during token passing."""

    def __init__(self) -> None:
        # Inactive tokens have not been reached by any path yet.
        self.is_active = False
        self.total_score = -float("inf")
        self.keyword_frames = 0
        self.average_keyword_score = -float("inf")
        self.average_max_keyword_score = 0.0

    def set_token(
        self,
        src_token,
        is_keyword_ilabel: bool,
        acoustic_score: float,
    ) -> None:
        """
        A dynamic programming process computing the highest score for a token
        from all possible paths which could reach this token.

        Args:
          src_token: The source token connected to current token with an arc.
          is_keyword_ilabel: If true, the arc consumes an input label which is
            a part of wake word. Otherwise, the input label is
            blank or unknown, i.e. current token is still not part of wake word.
          acoustic_score: acoustic score of this arc.
        """
        candidate_score = src_token.total_score + acoustic_score
        # Keep the existing path when it is already active and at least
        # as good as the new candidate (token recombination).
        if self.is_active and self.total_score >= candidate_score:
            return
        self.is_active = True
        self.total_score = candidate_score
        if not is_keyword_ilabel:
            # Not inside the wake word: the running average is cleared.
            self.average_keyword_score = 0.0
            return
        # Extend the running average of per-frame keyword scores by one frame.
        new_frames = src_token.keyword_frames + 1
        self.average_keyword_score = (
            acoustic_score
            + src_token.average_keyword_score * src_token.keyword_frames
        ) / new_frames
        self.keyword_frames = new_frames
class SingleDecodable:
    """Per-utterance token-passing decoder over the wake-word graph.

    Tokens are propagated one acoustic frame at a time; the utterance-level
    wake-word score is the best exp(average keyword score) observed at the
    graph's final state over all frames.
    """

    def __init__(
        self,
        model_output,
        keyword_ilabel_start,
        graph,
    ):
        """
        Args:
          model_output: log_softmax(logit) with shape [T, C]
          keyword_ilabel_start: index of the first token of the wake word.
            In this recipe, tokens not for wake word has smaller token index,
            i.e. blank 0; unk 1.
          graph: decoding graph of the wake word.
        """
        # One Token per graph state; kept pristine as a template for resets.
        self.init_token_list = [Token() for i in range(len(graph.state_list))]
        self.reset_token_list()
        self.model_output = model_output
        # Number of acoustic frames in this utterance.
        self.T = model_output.shape[0]
        # Best wake-word probability seen so far over all decoded frames.
        self.utt_score = 0.0
        self.current_frame_index = 0
        self.keyword_ilabel_start = keyword_ilabel_start
        self.graph = graph
        self.number_tokens = len(self.cur_token_list)

    def reset_token_list(self) -> None:
        """
        Reset all tokens to a condition without consuming any acoustic frames.
        """
        # Deep copies so per-frame mutation never touches init_token_list.
        self.cur_token_list = copy.deepcopy(self.init_token_list)
        self.expand_token_list = copy.deepcopy(self.init_token_list)
        # Only the start state (state 0) is active before the first frame.
        self.cur_token_list[0].is_active = True
        self.cur_token_list[0].total_score = 0
        self.cur_token_list[0].average_keyword_score = 0

    def process_oneframe(self) -> None:
        """
        Decode a frame and update all tokens.
        """
        for state_id, cur_token in enumerate(self.cur_token_list):
            if cur_token.is_active:
                # NOTE: each element of arc_list is an Arc object, despite
                # the loop variable being named arc_id.
                for arc_id in self.graph.state_list[state_id].arc_list:
                    acoustic_score = self.model_output[self.current_frame_index][
                        arc_id.ilabel
                    ]
                    # Labels >= keyword_ilabel_start belong to the wake word;
                    # smaller labels are blank/unknown.
                    is_keyword_ilabel = arc_id.ilabel >= self.keyword_ilabel_start
                    self.expand_token_list[arc_id.next_state()].set_token(
                        cur_token,
                        is_keyword_ilabel,
                        acoustic_score,
                    )
        # use best_score to keep total_score in a good range
        self.best_state_id = 0
        best_score = self.expand_token_list[0].total_score
        for state_id in range(self.number_tokens):
            if self.expand_token_list[state_id].is_active:
                if best_score < self.expand_token_list[state_id].total_score:
                    best_score = self.expand_token_list[state_id].total_score
                    self.best_state_id = state_id
        self.cur_token_list = self.expand_token_list
        for state_id in range(self.number_tokens):
            # Shift all scores so the best token sits at 0; this prevents
            # unbounded drift of log-scores over long utterances.
            self.cur_token_list[state_id].total_score -= best_score
        self.expand_token_list = copy.deepcopy(self.init_token_list)
        # Wake-word confidence at this frame: probability-domain average
        # keyword score of the token sitting in the graph's final state.
        potential_score = np.exp(
            self.cur_token_list[self.graph.final_state_id].average_keyword_score
        )
        if potential_score > self.utt_score:
            self.utt_score = potential_score
        self.current_frame_index += 1
def decode_utt(
    params: AttributeDict, utt_id: str, post_file, graph: FiniteStateTransducer
) -> Tuple[str, float]:
    """
    Decode a single utterance.

    Args:
      params:
        The return value of :func:`get_params`.
      utt_id: utt_id to be decoded, used to fetch posterior matrix from post_file.
      post_file: file to save posterior for all test set.
      graph: decoding graph.

    Returns:
      utt_id and its corresponding probability to be a wake word.
    """
    # Fetch this utterance's posterior matrix from the h5 file.
    posterior = NumpyHdf5Reader(post_file).read(utt_id)
    decodable = SingleDecodable(
        model_output=posterior,
        keyword_ilabel_start=params.wakeup_word_tokens[0],
        graph=graph,
    )
    # Consume every acoustic frame of the utterance.
    for _ in range(decodable.T):
        decodable.process_oneframe()
    return utt_id, decodable.utt_score
def get_parser():
    """Build the command-line argument parser for this decoder script."""
    parser = argparse.ArgumentParser(
        description="A simple FST decoder for the wake word detection\n"
    )
    # Both options are plain string-valued flags.
    for flag, help_text in (
        ("--post-h5", "model output in h5 format"),
        ("--score-file", "file to save scores of each utterance"),
    ):
        parser.add_argument(flag, type=str, help=help_text)
    return parser
def main():
    """Entry point: score every utterance in the posterior file."""
    params = get_params()
    params.update(vars(get_parser().parse_args()))

    post_path = Path(params.post_h5)
    post_dir = post_path.parent
    test_set = post_path.stem
    setup_logger(f"{post_dir}/log/log-decode-{test_set}")

    graph = FiniteStateTransducer(
        ctc_trivial_decoding_graph(params.wakeup_word_tokens)
    )
    logging.info(f"Graph used:\n{graph.to_str()}")

    logging.info(f"About to load {test_set}.")
    keys = NumpyHdf5Reader(params.post_h5).hdf.keys()
    with ProcessPoolExecutor() as executor, open(
        params.score_file, "w", encoding="utf8"
    ) as fout:
        # Fan the utterances out to worker processes, then collect scores
        # in submission order.
        futures = [
            executor.submit(decode_utt, params, key, params.post_h5, graph)
            for key in tqdm(keys)
        ]
        logging.info(f"Decoding {test_set}.")
        for future in tqdm(futures):
            utt_id, score = future.result()
            fout.write(f"{utt_id} {score}\n")
    logging.info(f"Finish decoding {test_set}.")


if __name__ == "__main__":
    main()