Mirror of https://github.com/k2-fsa/icefall.git
Various fixes to the context graph to support the KWS (keyword spotting) system and fix hotword bugs
Commit 17dab02dc9 (parent 11d816d174)
@@ -17,7 +17,7 @@
 import os
 import shutil
 from collections import deque
-from typing import Dict, List, Optional, Tuple
+from typing import Dict, List, Optional, Tuple, Union
 
 
 class ContextState:
@@ -31,6 +31,9 @@ class ContextState:
         node_score: float,
         output_score: float,
         is_end: bool,
+        level: int,
+        phrase: str = "",
+        ac_threshold: float = 1.0,
     ):
         """Create a ContextState.
 
@@ -51,6 +54,15 @@ class ContextState:
             the output node for current node.
           is_end:
             True if current token is the end of a context.
+          level:
+            The distance from current node to root.
+          phrase:
+            The context phrase of current state, the value is valid only when
+            current state is end state (is_end == True).
+          ac_threshold:
+            The acoustic threshold (probability) of current context phrase, the
+            value is valid only when current state is end state (is_end == True).
+            Note: ac_threshold only used in keywords spotting.
         """
         self.id = id
         self.token = token
@@ -58,7 +70,10 @@ class ContextState:
         self.node_score = node_score
         self.output_score = output_score
         self.is_end = is_end
+        self.level = level
         self.next = {}
+        self.phrase = phrase
+        self.ac_threshold = ac_threshold
         self.fail = None
         self.output = None
 
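
For readers skimming the diff, here is a minimal sketch (illustrative values only, not taken from this commit) of what the three new fields describe on a trie node: `level` is the depth from the root, while `phrase` and `ac_threshold` only carry meaningful values on end states.

# Hypothetical end state for a 5-token phrase; all numbers are made up.
from icefall.context_graph import ContextState

end_state = ContextState(
    id=5,
    token=42,           # id of the last token of the phrase
    token_score=1.0,    # per-token boosting score
    node_score=5.0,     # accumulated score from the root to this node
    output_score=5.0,   # non-zero because this node ends a phrase
    is_end=True,
    level=5,            # 5 tokens away from the root
    phrase="HELLO",     # only meaningful because is_end is True
    ac_threshold=0.2,   # acoustic trigger probability, used for keyword spotting
)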
@@ -75,7 +90,7 @@ class ContextGraph:
     beam search.
     """
 
-    def __init__(self, context_score: float):
+    def __init__(self, context_score: float, ac_threshold: float = 1.0):
         """Initialize a ContextGraph with the given ``context_score``.
 
         A root node will be created (**NOTE:** the token of root is hardcoded to -1).
@@ -87,8 +102,12 @@ class ContextGraph:
             Note: This is just the default score for each token, the users can manually
             specify the context_score for each word/phrase (i.e. different phrase might
             have different token score).
+          ac_threshold:
+            The acoustic threshold (probability) to trigger the word/phrase, this argument
+            is used only when applying the graph to keywords spotting system.
         """
         self.context_score = context_score
+        self.ac_threshold = ac_threshold
         self.num_nodes = 0
         self.root = ContextState(
             id=self.num_nodes,
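
A short usage sketch of the extended constructor; the numbers are illustrative defaults, not values from this commit.

from icefall.context_graph import ContextGraph

# Default per-token boost of 2.0 and a default acoustic trigger threshold of
# 0.15; the threshold only matters when the graph drives keyword spotting.
graph = ContextGraph(context_score=2.0, ac_threshold=0.15)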
@@ -97,6 +116,7 @@ class ContextGraph:
             node_score=0,
             output_score=0,
             is_end=False,
+            level=0,
         )
         self.root.fail = self.root
 
@@ -136,7 +156,13 @@ class ContextGraph:
             node.output_score += 0 if output is None else output.output_score
             queue.append(node)
 
-    def build(self, token_ids: List[Tuple[List[int], float]]):
+    def build(
+        self,
+        token_ids: List[List[int]],
+        phrases: Optional[List[str]] = None,
+        scores: Optional[List[float]] = None,
+        ac_thresholds: Optional[List[float]] = None,
+    ):
         """Build the ContextGraph from a list of token list.
         It first build a trie from the given token lists, then fill the fail arc
         for each trie node.
@@ -145,52 +171,80 @@ class ContextGraph:
 
         Args:
           token_ids:
-            The given token lists to build the ContextGraph, it is a list of tuple of
-            token list and its customized score, the token list contains the token ids
+            The given token lists to build the ContextGraph, it is a list of
+            token list, the token list contains the token ids
             for a word/phrase. The token id could be an id of a char
             (modeling with single Chinese char) or an id of a BPE
-            (modeling with BPEs). The score is the total score for current token list,
+            (modeling with BPEs).
+          phrases:
+            The given phrases, they are the original text of the token_ids, the
+            length of `phrases` MUST be equal to the length of `token_ids`.
+          scores:
+            The customize boosting score(token level) for each word/phrase,
             0 means using the default value (i.e. self.context_score).
+            It is a list of floats, and the length of `scores` MUST be equal to
+            the length of `token_ids`.
+          ac_thresholds:
+            The customize trigger acoustic threshold (probability) for each phrase,
+            0 means using the default value (i.e. self.ac_threshold). It is
+            used only when this graph applied for the keywords spotting system.
+            The length of `ac_threshold` MUST be equal to the length of `token_ids`.
 
         Note: The phrases would have shared states, the score of the shared states is
-        the maximum value among all the tokens sharing this state.
+        the MAXIMUM value among all the tokens sharing this state.
         """
-        for (tokens, score) in token_ids:
+        num_phrases = len(token_ids)
+        if phrases is not None:
+            assert len(phrases) == num_phrases, (len(phrases), num_phrases)
+        if scores is not None:
+            assert len(scores) == num_phrases, (len(scores), num_phrases)
+        if ac_thresholds is not None:
+            assert len(ac_thresholds) == num_phrases, (len(ac_thresholds), num_phrases)
+
+        for index, tokens in enumerate(token_ids):
+            phrase = "" if phrases is None else phrases[index]
+            score = 0.0 if scores is None else scores[index]
+            ac_threshold = 0.0 if ac_thresholds is None else ac_thresholds[index]
             node = self.root
             # If has customized score using the customized token score, otherwise
             # using the default score
-            context_score = (
-                self.context_score if score == 0.0 else round(score / len(tokens), 2)
-            )
+            context_score = self.context_score if score == 0.0 else score
+            threshold = self.ac_threshold if ac_threshold == 0.0 else ac_threshold
             for i, token in enumerate(tokens):
                 node_next = {}
                 if token not in node.next:
                     self.num_nodes += 1
-                    node_id = self.num_nodes
-                    token_score = context_score
                     is_end = i == len(tokens) - 1
+                    node_score = node.node_score + context_score
+                    node.next[token] = ContextState(
+                        id=self.num_nodes,
+                        token=token,
+                        token_score=context_score,
+                        node_score=node_score,
+                        output_score=node_score if is_end else 0,
+                        is_end=is_end,
+                        level=i + 1,
+                        phrase=phrase if is_end else "",
+                        ac_threshold=threshold if is_end else 0.0,
+                    )
                 else:
                     # node exists, get the score of shared state.
                     token_score = max(context_score, node.next[token].token_score)
-                    node_id = node.next[token].id
-                    node_next = node.next[token].next
+                    node.next[token].token_score = token_score
+                    node_score = node.node_score + token_score
+                    node.next[token].node_score = node_score
                     is_end = i == len(tokens) - 1 or node.next[token].is_end
-                node_score = node.node_score + token_score
-                node.next[token] = ContextState(
-                    id=node_id,
-                    token=token,
-                    token_score=token_score,
-                    node_score=node_score,
-                    output_score=node_score if is_end else 0,
-                    is_end=is_end,
-                )
-                node.next[token].next = node_next
+                    node.next[token].output_score = node_score if is_end else 0
+                    node.next[token].is_end = is_end
+                    if i == len(tokens) - 1:
+                        node.next[token].phrase = phrase
+                        node.next[token].ac_threshold = threshold
                 node = node.next[token]
         self._fill_fail_output()
 
     def forward_one_step(
-        self, state: ContextState, token: int
-    ) -> Tuple[float, ContextState]:
+        self, state: ContextState, token: int, strict_mode: bool = True
+    ) -> Tuple[float, ContextState, ContextState]:
         """Search the graph with given state and token.
 
         Args:
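
With the new signature, `build` takes parallel lists instead of `(tokens, score)` tuples. Below is a hedged example of a call under this commit's API; the token ids and numbers are invented, and a 0 entry in `scores` or `ac_thresholds` falls back to the graph-level defaults as described above.

graph = ContextGraph(context_score=2.0, ac_threshold=0.15)
graph.build(
    token_ids=[[23, 7, 91], [44, 8]],  # token ids per phrase (made-up BPE ids)
    phrases=["HEY SIRI", "STOP"],      # original text, same length as token_ids
    scores=[3.0, 0],                   # per-token boost; 0 -> default context_score
    ac_thresholds=[0.25, 0],           # trigger probability; 0 -> default ac_threshold
)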
@@ -198,9 +252,27 @@ class ContextGraph:
             The given token containing trie node to start.
           token:
             The given token.
+          strict_mode:
+            If the `strict_mode` is True, it can match multiple phrases simultaneously,
+            and will continue to match longer phrase after matching a shorter one.
+            If the `strict_mode` is False, it can only match one phrase at a time,
+            when it matches a phrase, then the state will fall back to root state
+            (i.e. forgetting all the history state and starting a new match). If
+            the matched state have multiple outputs (node.output is not None), the
+            longest phrase will be return.
+            For example, if the phrases are `he`, `she` and `shell`, the query is
+            `like shell`, when `strict_mode` is True, the query will match `he` and
+            `she` at token `e` and `shell` at token `l`, while when `strict_mode`
+            if False, the query can only match `she`(`she` is longer than `he`, so
+            `she` not `he`) at token `e`.
+            Caution: When applying this graph for keywords spotting system, the
+            `strict_mode` MUST be True.
 
         Returns:
-          Return a tuple of score and next state.
+          Return a tuple of boosting score for current state, next state and matched
+          state (if any). Note: Only returns the matched state with longest phrase of
+          current state, even if there are multiple matches phrases. If no phrase
+          matched, the matched state is None.
         """
         node = None
         score = 0
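
To make the `he`/`she`/`shell` example from the docstring concrete, here is a small sketch (character codes as token ids, an assumed setup rather than code from this commit) comparing the two modes.

from icefall.context_graph import ContextGraph

graph = ContextGraph(context_score=1)
words = ["he", "she", "shell"]
graph.build(token_ids=[[ord(c) for c in w] for w in words], phrases=words)

def matched_phrases(query: str, strict_mode: bool):
    state = graph.root
    matches = []
    for ch in query:
        _, state, matched = graph.forward_one_step(state, ord(ch), strict_mode)
        if matched is not None:
            # Only the longest phrase ending at this step is reported.
            matches.append(matched.phrase)
    return matches

print(matched_phrases("like shell", True))   # expected: ["she", "shell"]
print(matched_phrases("like shell", False))  # expected: ["she"] (state falls back to root)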
@@ -224,7 +296,31 @@ class ContextGraph:
         # The score of the fail path
         score = node.node_score - state.node_score
         assert node is not None
-        return (score + node.output_score, node)
+
+        # The matched node of current step, will only return the node with
+        # longest phrase if there are multiple phrases matches this step.
+        # None if no matched phrase.
+        matched_node = (
+            node if node.is_end else (node.output if node.output is not None else None)
+        )
+        if not strict_mode and node.output_score != 0:
+            # output_score != 0 means at least on phrase matched
+            assert matched_node is not None
+            output_score = (
+                node.node_score
+                if node.is_end
+                else (
+                    node.node_score if node.output is None else node.output.node_score
+                )
+            )
+            return (score + output_score - node.node_score, self.root, matched_node)
+        assert (node.output_score != 0 and matched_node is not None) or (
+            node.output_score == 0 and matched_node is None
+        ), (
+            node.output_score,
+            matched_node,
+        )
+        return (score + node.output_score, node, matched_node)
 
     def finalize(self, state: ContextState) -> Tuple[float, ContextState]:
         """When reaching the end of the decoded sequence, we need to finalize
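
Since `forward_one_step` now also returns the matched state, a keyword-spotting style consumer can inspect `phrase` and `ac_threshold` on it. The following is a simplified sketch under assumed inputs; how the acoustic confidence is actually accumulated in icefall's decoder is not part of this diff, so a single per-token probability stands in for it here.

from typing import List
from icefall.context_graph import ContextGraph

def spot_keywords(graph: ContextGraph, tokens: List[int], probs: List[float]):
    """Return phrases whose end state is reached with enough acoustic confidence."""
    state = graph.root
    hits = []
    for token, prob in zip(tokens, probs):
        # strict_mode must be True for keyword spotting (see the docstring above).
        _, state, matched = graph.forward_one_step(state, token, strict_mode=True)
        if matched is not None and prob >= matched.ac_threshold:
            hits.append(matched.phrase)
    return hits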
@@ -366,7 +462,7 @@
         return dot
 
 
-def _test(queries, score):
+def _test(queries, score, strict_mode):
     contexts_str = [
         "S",
         "HE",
@@ -381,11 +477,15 @@ def _test(queries, score):
 
     # test default score (1)
     contexts = []
+    scores = []
+    phrases = []
     for s in contexts_str:
-        contexts.append(([ord(x) for x in s], score))
+        contexts.append([ord(x) for x in s])
+        scores.append(round(score / len(s), 2))
+        phrases.append(s)
 
     context_graph = ContextGraph(context_score=1)
-    context_graph.build(contexts)
+    context_graph.build(token_ids=contexts, scores=scores, phrases=phrases)
 
     symbol_table = {}
     for contexts in contexts_str:
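
The test now splits the phrase-level score across tokens itself (`round(score / len(s), 2)`) instead of letting `build` do it. For the custom-score run with score = 5, the arithmetic works out as in the comments further down the file:

# Per-token boosts for score = 5, matching the "# SHE : 8.34 (5 + 1.67 + 1.67)"
# style comments below.
for s in ["S", "HE", "SHE"]:
    print(s, round(5 / len(s), 2))  # S -> 5.0, HE -> 2.5, SHE -> 1.67
# Shared prefix states keep the MAXIMUM per-token score, so the shared "S"
# node stays at 5.0 and "SHE" accumulates 5 + 1.67 + 1.67 = 8.34.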
@@ -402,7 +502,9 @@ def _test(queries, score):
         total_scores = 0
         state = context_graph.root
         for q in query:
-            score, state = context_graph.forward_one_step(state, ord(q))
+            score, state, phrase = context_graph.forward_one_step(
+                state, ord(q), strict_mode
+            )
             total_scores += score
         score, state = context_graph.finalize(state)
         assert state.token == -1, state.token
@@ -427,9 +529,22 @@ if __name__ == "__main__":
         "DHRHISQ": 4,  # "HIS", "S"
         "THEN": 2,  # "HE"
     }
-    _test(queries, 0)
+    _test(queries, 0, True)
 
-    # test custom score (5)
+    queries = {
+        "HEHERSHE": 7,  # "HE", "HE", "S", "HE"
+        "HERSHE": 5,  # "HE", "S", "HE"
+        "HISHE": 5,  # "HIS", "HE"
+        "SHED": 3,  # "S", "HE"
+        "SHELF": 3,  # "S", "HE"
+        "HELL": 2,  # "HE"
+        "HELLO": 2,  # "HE"
+        "DHRHISQ": 3,  # "HIS"
+        "THEN": 2,  # "HE"
+    }
+    _test(queries, 0, False)
+
+    # test custom score
     # S : 5
     # HE : 5 (2.5 + 2.5)
     # SHE : 8.34 (5 + 1.67 + 1.67)
@@ -450,4 +565,17 @@ if __name__ == "__main__":
         "THEN": 5,  # "HE"
     }
 
-    _test(queries, 5)
+    _test(queries, 5, True)
+
+    queries = {
+        "HEHERSHE": 20,  # "HE", "HE", "S", "HE"
+        "HERSHE": 15,  # "HE", "S", "HE"
+        "HISHE": 10.84,  # "HIS", "HE"
+        "SHED": 10,  # "S", "HE"
+        "SHELF": 10,  # "S", "HE"
+        "HELL": 5,  # "HE"
+        "HELLO": 5,  # "HE"
+        "DHRHISQ": 5.84,  # "HIS"
+        "THEN": 5,  # "HE"
+    }
+    _test(queries, 5, False)