Various fixes to the context graph to support the KWS system and fix hotword bugs

This commit is contained in:
pkufool 2023-12-25 12:18:32 +08:00
parent 11d816d174
commit 17dab02dc9

View File

@ -17,7 +17,7 @@
import os
import shutil
from collections import deque
from typing import Dict, List, Optional, Tuple
from typing import Dict, List, Optional, Tuple, Union
class ContextState:
@ -31,6 +31,9 @@ class ContextState:
node_score: float,
output_score: float,
is_end: bool,
level: int,
phrase: str = "",
ac_threshold: float = 1.0,
):
"""Create a ContextState.
@ -51,6 +54,15 @@ class ContextState:
the output node for current node.
is_end:
True if current token is the end of a context.
level:
The distance from current node to root.
phrase:
The context phrase of current state, the value is valid only when
current state is end state (is_end == True).
ac_threshold:
The acoustic threshold (probability) of current context phrase, the
value is valid only when current state is end state (is_end == True).
Note: ac_threshold is only used in keyword spotting.
"""
self.id = id
self.token = token
@ -58,7 +70,10 @@ class ContextState:
self.node_score = node_score
self.output_score = output_score
self.is_end = is_end
self.level = level
self.next = {}
self.phrase = phrase
self.ac_threshold = ac_threshold
self.fail = None
self.output = None
@ -75,7 +90,7 @@ class ContextGraph:
beam search.
"""
def __init__(self, context_score: float):
def __init__(self, context_score: float, ac_threshold: float = 1.0):
"""Initialize a ContextGraph with the given ``context_score``.
A root node will be created (**NOTE:** the token of root is hardcoded to -1).
@ -87,8 +102,12 @@ class ContextGraph:
Note: This is just the default score for each token, the users can manually
specify the context_score for each word/phrase (i.e. different phrase might
have different token score).
ac_threshold:
The acoustic threshold (probability) to trigger the word/phrase, this argument
is used only when applying the graph to keywords spotting system.
"""
self.context_score = context_score
self.ac_threshold = ac_threshold
self.num_nodes = 0
self.root = ContextState(
id=self.num_nodes,
@ -97,6 +116,7 @@ class ContextGraph:
node_score=0,
output_score=0,
is_end=False,
level=0,
)
self.root.fail = self.root
@ -136,7 +156,13 @@ class ContextGraph:
node.output_score += 0 if output is None else output.output_score
queue.append(node)
def build(self, token_ids: List[Tuple[List[int], float]]):
def build(
self,
token_ids: List[List[int]],
phrases: Optional[List[str]] = None,
scores: Optional[List[float]] = None,
ac_thresholds: Optional[List[float]] = None,
):
"""Build the ContextGraph from a list of token list.
It first builds a trie from the given token lists, then fills in the fail arc
for each trie node.
@ -145,52 +171,80 @@ class ContextGraph:
Args:
token_ids:
The given token lists to build the ContextGraph, it is a list of tuple of
token list and its customized score, the token list contains the token ids
The given token lists to build the ContextGraph, it is a list of
token list, the token list contains the token ids
for a word/phrase. The token id could be an id of a char
(modeling with single Chinese char) or an id of a BPE
(modeling with BPEs). The score is the total score for current token list,
(modeling with BPEs).
phrases:
The given phrases, they are the original text of the token_ids, the
length of `phrases` MUST be equal to the length of `token_ids`.
scores:
The customized boosting score (token level) for each word/phrase,
0 means using the default value (i.e. self.context_score).
It is a list of floats, and the length of `scores` MUST be equal to
the length of `token_ids`.
ac_thresholds:
The customized trigger acoustic threshold (probability) for each phrase,
0 means using the default value (i.e. self.ac_threshold). It is
used only when this graph applied for the keywords spotting system.
The length of `ac_thresholds` MUST be equal to the length of `token_ids`.
Note: The phrases would have shared states, the score of the shared states is
the maximum value among all the tokens sharing this state.
the MAXIMUM value among all the tokens sharing this state.
"""
for (tokens, score) in token_ids:
num_phrases = len(token_ids)
if phrases is not None:
assert len(phrases) == num_phrases, (len(phrases), num_phrases)
if scores is not None:
assert len(scores) == num_phrases, (len(scores), num_phrases)
if ac_thresholds is not None:
assert len(ac_thresholds) == num_phrases, (len(ac_thresholds), num_phrases)
for index, tokens in enumerate(token_ids):
phrase = "" if phrases is None else phrases[index]
score = 0.0 if scores is None else scores[index]
ac_threshold = 0.0 if ac_thresholds is None else ac_thresholds[index]
node = self.root
# If a customized score is given, use it; otherwise use the
# default score.
context_score = (
self.context_score if score == 0.0 else round(score / len(tokens), 2)
)
context_score = self.context_score if score == 0.0 else score
threshold = self.ac_threshold if ac_threshold == 0.0 else ac_threshold
for i, token in enumerate(tokens):
node_next = {}
if token not in node.next:
self.num_nodes += 1
node_id = self.num_nodes
token_score = context_score
is_end = i == len(tokens) - 1
node_score = node.node_score + context_score
node.next[token] = ContextState(
id=self.num_nodes,
token=token,
token_score=context_score,
node_score=node_score,
output_score=node_score if is_end else 0,
is_end=is_end,
level=i + 1,
phrase=phrase if is_end else "",
ac_threshold=threshold if is_end else 0.0,
)
else:
# node exists, get the score of shared state.
token_score = max(context_score, node.next[token].token_score)
node_id = node.next[token].id
node_next = node.next[token].next
node.next[token].token_score = token_score
node_score = node.node_score + token_score
node.next[token].node_score = node_score
is_end = i == len(tokens) - 1 or node.next[token].is_end
node_score = node.node_score + token_score
node.next[token] = ContextState(
id=node_id,
token=token,
token_score=token_score,
node_score=node_score,
output_score=node_score if is_end else 0,
is_end=is_end,
)
node.next[token].next = node_next
node.next[token].output_score = node_score if is_end else 0
node.next[token].is_end = is_end
if i == len(tokens) - 1:
node.next[token].phrase = phrase
node.next[token].ac_threshold = threshold
node = node.next[token]
self._fill_fail_output()
def forward_one_step(
self, state: ContextState, token: int
) -> Tuple[float, ContextState]:
self, state: ContextState, token: int, strict_mode: bool = True
) -> Tuple[float, ContextState, ContextState]:
"""Search the graph with given state and token.
Args:
@ -198,9 +252,27 @@ class ContextGraph:
The given token containing trie node to start.
token:
The given token.
strict_mode:
If the `strict_mode` is True, it can match multiple phrases simultaneously,
and will continue to match longer phrase after matching a shorter one.
If the `strict_mode` is False, it can only match one phrase at a time,
when it matches a phrase, then the state will fall back to root state
(i.e. forgetting all the history state and starting a new match). If
the matched state has multiple outputs (node.output is not None), the
longest phrase will be returned.
For example, if the phrases are `he`, `she` and `shell`, the query is
`like shell`, when `strict_mode` is True, the query will match `he` and
`she` at token `e` and `shell` at token `l`, while when `strict_mode`
is False, the query can only match `she` (`she` is longer than `he`, so
`she` not `he`) at token `e`.
Caution: When applying this graph for keywords spotting system, the
`strict_mode` MUST be True.
Returns:
Return a tuple of score and next state.
Return a tuple of boosting score for current state, next state and matched
state (if any). Note: Only returns the matched state with longest phrase of
current state, even if there are multiple matched phrases. If no phrase is
matched, the matched state is None.
"""
node = None
score = 0
@ -224,7 +296,31 @@ class ContextGraph:
# The score of the fail path
score = node.node_score - state.node_score
assert node is not None
return (score + node.output_score, node)
# The matched node of current step, will only return the node with
# longest phrase if there are multiple phrases matches this step.
# None if no matched phrase.
matched_node = (
node if node.is_end else (node.output if node.output is not None else None)
)
if not strict_mode and node.output_score != 0:
# output_score != 0 means at least one phrase matched
assert matched_node is not None
output_score = (
node.node_score
if node.is_end
else (
node.node_score if node.output is None else node.output.node_score
)
)
return (score + output_score - node.node_score, self.root, matched_node)
assert (node.output_score != 0 and matched_node is not None) or (
node.output_score == 0 and matched_node is None
), (
node.output_score,
matched_node,
)
return (score + node.output_score, node, matched_node)
def finalize(self, state: ContextState) -> Tuple[float, ContextState]:
"""When reaching the end of the decoded sequence, we need to finalize
@ -366,7 +462,7 @@ class ContextGraph:
return dot
def _test(queries, score):
def _test(queries, score, strict_mode):
contexts_str = [
"S",
"HE",
@ -381,11 +477,15 @@ def _test(queries, score):
# test default score (1)
contexts = []
scores = []
phrases = []
for s in contexts_str:
contexts.append(([ord(x) for x in s], score))
contexts.append([ord(x) for x in s])
scores.append(round(score / len(s), 2))
phrases.append(s)
context_graph = ContextGraph(context_score=1)
context_graph.build(contexts)
context_graph.build(token_ids=contexts, scores=scores, phrases=phrases)
symbol_table = {}
for contexts in contexts_str:
@ -402,7 +502,9 @@ def _test(queries, score):
total_scores = 0
state = context_graph.root
for q in query:
score, state = context_graph.forward_one_step(state, ord(q))
score, state, phrase = context_graph.forward_one_step(
state, ord(q), strict_mode
)
total_scores += score
score, state = context_graph.finalize(state)
assert state.token == -1, state.token
@ -427,9 +529,22 @@ if __name__ == "__main__":
"DHRHISQ": 4, # "HIS", "S"
"THEN": 2, # "HE"
}
_test(queries, 0)
_test(queries, 0, True)
# test custom score (5)
queries = {
"HEHERSHE": 7, # "HE", "HE", "S", "HE"
"HERSHE": 5, # "HE", "S", "HE"
"HISHE": 5, # "HIS", "HE"
"SHED": 3, # "S", "HE"
"SHELF": 3, # "S", "HE"
"HELL": 2, # "HE"
"HELLO": 2, # "HE"
"DHRHISQ": 3, # "HIS"
"THEN": 2, # "HE"
}
_test(queries, 0, False)
# test custom score
# S : 5
# HE : 5 (2.5 + 2.5)
# SHE : 8.34 (5 + 1.67 + 1.67)
@ -450,4 +565,17 @@ if __name__ == "__main__":
"THEN": 5, # "HE"
}
_test(queries, 5)
_test(queries, 5, True)
queries = {
"HEHERSHE": 20, # "HE", "HE", "S", "HE"
"HERSHE": 15, # "HE", "S", "HE"
"HISHE": 10.84, # "HIS", "HE"
"SHED": 10, # "S", "HE"
"SHELF": 10, # "S", "HE"
"HELL": 5, # "HE"
"HELLO": 5, # "HE"
"DHRHISQ": 5.84, # "HIS"
"THEN": 5, # "HE"
}
_test(queries, 5, False)