Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-09-03 22:24:19 +00:00)
Various fixes to the context graph to support the KWS system and to fix hotword bugs
Commit: 17dab02dc9
Parent: 11d816d174
@@ -17,7 +17,7 @@
 import os
 import shutil
 from collections import deque
-from typing import Dict, List, Optional, Tuple
+from typing import Dict, List, Optional, Tuple, Union


 class ContextState:
@@ -31,6 +31,9 @@ class ContextState:
         node_score: float,
         output_score: float,
         is_end: bool,
+        level: int,
+        phrase: str = "",
+        ac_threshold: float = 1.0,
     ):
         """Create a ContextState.

@@ -51,6 +54,15 @@ class ContextState:
             the output node for current node.
           is_end:
             True if current token is the end of a context.
+          level:
+            The distance from current node to root.
+          phrase:
+            The context phrase of current state, the value is valid only when
+            current state is end state (is_end == True).
+          ac_threshold:
+            The acoustic threshold (probability) of current context phrase, the
+            value is valid only when current state is end state (is_end == True).
+            Note: ac_threshold only used in keywords spotting.
         """
         self.id = id
         self.token = token
@@ -58,7 +70,10 @@ class ContextState:
         self.node_score = node_score
         self.output_score = output_score
         self.is_end = is_end
+        self.level = level
         self.next = {}
+        self.phrase = phrase
+        self.ac_threshold = ac_threshold
         self.fail = None
         self.output = None

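The two hunks above give each ContextState a level (its distance from the root), the phrase it terminates, and a per-phrase acoustic threshold. Below is a minimal hand-built node, only to make the new fields concrete; the import path and the concrete values are assumptions and not part of the diff (in practice nodes are created inside ContextGraph.build()):

    from icefall.context_graph import ContextState  # assumed import path

    # End-of-phrase node for a two-token keyword "HI" (character codes used as
    # token ids, as the test helper in this file does). Only end states carry a
    # non-empty phrase and a meaningful ac_threshold.
    node = ContextState(
        id=2,
        token=ord("I"),
        token_score=1.0,
        node_score=2.0,
        output_score=2.0,
        is_end=True,
        level=2,
        phrase="HI",
        ac_threshold=0.25,
    )
    assert node.phrase == "HI" and node.level == 2 and node.is_end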
@@ -75,7 +90,7 @@ class ContextGraph:
     beam search.
     """

-    def __init__(self, context_score: float):
+    def __init__(self, context_score: float, ac_threshold: float = 1.0):
        """Initialize a ContextGraph with the given ``context_score``.

        A root node will be created (**NOTE:** the token of root is hardcoded to -1).
@@ -87,8 +102,12 @@ class ContextGraph:
            Note: This is just the default score for each token, the users can manually
            specify the context_score for each word/phrase (i.e. different phrase might
            have different token score).
+          ac_threshold:
+            The acoustic threshold (probability) to trigger the word/phrase, this argument
+            is used only when applying the graph to keywords spotting system.
        """
        self.context_score = context_score
+        self.ac_threshold = ac_threshold
        self.num_nodes = 0
        self.root = ContextState(
            id=self.num_nodes,
@@ -97,6 +116,7 @@ class ContextGraph:
            node_score=0,
            output_score=0,
            is_end=False,
+            level=0,
        )
        self.root.fail = self.root

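With the constructor change above, the graph itself now carries a default acoustic threshold for keyword spotting. A hedged usage sketch follows; the import path and the numbers are assumptions, not taken from the diff:

    from icefall.context_graph import ContextGraph  # assumed import path

    # Boost every keyword token by 2.0 and require an average per-token acoustic
    # probability of 0.3 before a keyword can fire (illustrative values only).
    kws_graph = ContextGraph(context_score=2.0, ac_threshold=0.3)
    assert kws_graph.root.token == -1  # the token of the root is hardcoded to -1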
@@ -136,7 +156,13 @@ class ContextGraph:
                node.output_score += 0 if output is None else output.output_score
                queue.append(node)

-    def build(self, token_ids: List[Tuple[List[int], float]]):
+    def build(
+        self,
+        token_ids: List[List[int]],
+        phrases: Optional[List[str]] = None,
+        scores: Optional[List[float]] = None,
+        ac_thresholds: Optional[List[float]] = None,
+    ):
        """Build the ContextGraph from a list of token list.
        It first build a trie from the given token lists, then fill the fail arc
        for each trie node.
@@ -145,52 +171,80 @@ class ContextGraph:

        Args:
          token_ids:
-            The given token lists to build the ContextGraph, it is a list of tuple of
-            token list and its customized score, the token list contains the token ids
+            The given token lists to build the ContextGraph, it is a list of
+            token list, the token list contains the token ids
            for a word/phrase. The token id could be an id of a char
            (modeling with single Chinese char) or an id of a BPE
-            (modeling with BPEs). The score is the total score for current token list,
+            (modeling with BPEs).
+          phrases:
+            The given phrases, they are the original text of the token_ids, the
+            length of `phrases` MUST be equal to the length of `token_ids`.
+          scores:
+            The customize boosting score(token level) for each word/phrase,
            0 means using the default value (i.e. self.context_score).
+            It is a list of floats, and the length of `scores` MUST be equal to
+            the length of `token_ids`.
+          ac_thresholds:
+            The customize trigger acoustic threshold (probability) for each phrase,
+            0 means using the default value (i.e. self.ac_threshold). It is
+            used only when this graph applied for the keywords spotting system.
+            The length of `ac_threshold` MUST be equal to the length of `token_ids`.

        Note: The phrases would have shared states, the score of the shared states is
-        the maximum value among all the tokens sharing this state.
+        the MAXIMUM value among all the tokens sharing this state.
        """
-        for (tokens, score) in token_ids:
+        num_phrases = len(token_ids)
+        if phrases is not None:
+            assert len(phrases) == num_phrases, (len(phrases), num_phrases)
+        if scores is not None:
+            assert len(scores) == num_phrases, (len(scores), num_phrases)
+        if ac_thresholds is not None:
+            assert len(ac_thresholds) == num_phrases, (len(ac_thresholds), num_phrases)
+
+        for index, tokens in enumerate(token_ids):
+            phrase = "" if phrases is None else phrases[index]
+            score = 0.0 if scores is None else scores[index]
+            ac_threshold = 0.0 if ac_thresholds is None else ac_thresholds[index]
            node = self.root
            # If has customized score using the customized token score, otherwise
            # using the default score
-            context_score = (
-                self.context_score if score == 0.0 else round(score / len(tokens), 2)
-            )
+            context_score = self.context_score if score == 0.0 else score
+            threshold = self.ac_threshold if ac_threshold == 0.0 else ac_threshold
            for i, token in enumerate(tokens):
                node_next = {}
                if token not in node.next:
                    self.num_nodes += 1
-                    node_id = self.num_nodes
-                    token_score = context_score
                    is_end = i == len(tokens) - 1
+                    node_score = node.node_score + context_score
+                    node.next[token] = ContextState(
+                        id=self.num_nodes,
+                        token=token,
+                        token_score=context_score,
+                        node_score=node_score,
+                        output_score=node_score if is_end else 0,
+                        is_end=is_end,
+                        level=i + 1,
+                        phrase=phrase if is_end else "",
+                        ac_threshold=threshold if is_end else 0.0,
+                    )
                else:
                    # node exists, get the score of shared state.
                    token_score = max(context_score, node.next[token].token_score)
-                    node_id = node.next[token].id
-                    node_next = node.next[token].next
+                    node.next[token].token_score = token_score
+                    node_score = node.node_score + token_score
+                    node.next[token].node_score = node_score
                    is_end = i == len(tokens) - 1 or node.next[token].is_end
-                node_score = node.node_score + token_score
-                node.next[token] = ContextState(
-                    id=node_id,
-                    token=token,
-                    token_score=token_score,
-                    node_score=node_score,
-                    output_score=node_score if is_end else 0,
-                    is_end=is_end,
-                )
-                node.next[token].next = node_next
+                    node.next[token].output_score = node_score if is_end else 0
+                    node.next[token].is_end = is_end
+                if i == len(tokens) - 1:
+                    node.next[token].phrase = phrase
+                    node.next[token].ac_threshold = threshold
                node = node.next[token]
        self._fill_fail_output()

    def forward_one_step(
-        self, state: ContextState, token: int
-    ) -> Tuple[float, ContextState]:
+        self, state: ContextState, token: int, strict_mode: bool = True
+    ) -> Tuple[float, ContextState, ContextState]:
        """Search the graph with given state and token.

        Args:
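Taken together, the new build() replaces the old list of (token_ids, score) tuples with parallel lists, so per-phrase boosting scores and acoustic thresholds become optional. A hedged sketch of how a keyword graph might be built with this API; the import path, keywords and token ids are invented for illustration (token ids would normally come from the ASR model's tokenizer):

    from icefall.context_graph import ContextGraph  # assumed import path

    keywords = ["HELLO WORLD", "HEY COMPUTER"]
    token_ids = [[23, 57, 89], [11, 42, 7]]  # made-up BPE ids for the keywords

    graph = ContextGraph(context_score=2.0, ac_threshold=0.2)
    graph.build(
        token_ids=token_ids,
        phrases=keywords,
        # 0 keeps the default per-token boost (self.context_score); a non-zero
        # value overrides the token-level boosting score for that phrase.
        scores=[0.0, 3.0],
        # 0 keeps the default acoustic threshold (self.ac_threshold); the second
        # keyword here must be acoustically more confident before it can fire.
        ac_thresholds=[0.0, 0.35],
    )
    # With no shared prefixes, each token becomes its own trie node.
    assert graph.num_nodes == 6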
@@ -198,9 +252,27 @@ class ContextGraph:
            The given token containing trie node to start.
          token:
            The given token.
+          strict_mode:
+            If the `strict_mode` is True, it can match multiple phrases simultaneously,
+            and will continue to match longer phrase after matching a shorter one.
+            If the `strict_mode` is False, it can only match one phrase at a time,
+            when it matches a phrase, then the state will fall back to root state
+            (i.e. forgetting all the history state and starting a new match). If
+            the matched state have multiple outputs (node.output is not None), the
+            longest phrase will be return.
+            For example, if the phrases are `he`, `she` and `shell`, the query is
+            `like shell`, when `strict_mode` is True, the query will match `he` and
+            `she` at token `e` and `shell` at token `l`, while when `strict_mode`
+            if False, the query can only match `she`(`she` is longer than `he`, so
+            `she` not `he`) at token `e`.
+            Caution: When applying this graph for keywords spotting system, the
+            `strict_mode` MUST be True.

        Returns:
-          Return a tuple of score and next state.
+          Return a tuple of boosting score for current state, next state and matched
+          state (if any). Note: Only returns the matched state with longest phrase of
+          current state, even if there are multiple matches phrases. If no phrase
+          matched, the matched state is None.
        """
        node = None
        score = 0
@@ -224,7 +296,31 @@ class ContextGraph:
                # The score of the fail path
                score = node.node_score - state.node_score
        assert node is not None
-        return (score + node.output_score, node)
+
+        # The matched node of current step, will only return the node with
+        # longest phrase if there are multiple phrases matches this step.
+        # None if no matched phrase.
+        matched_node = (
+            node if node.is_end else (node.output if node.output is not None else None)
+        )
+        if not strict_mode and node.output_score != 0:
+            # output_score != 0 means at least on phrase matched
+            assert matched_node is not None
+            output_score = (
+                node.node_score
+                if node.is_end
+                else (
+                    node.node_score if node.output is None else node.output.node_score
+                )
+            )
+            return (score + output_score - node.node_score, self.root, matched_node)
+        assert (node.output_score != 0 and matched_node is not None) or (
+            node.output_score == 0 and matched_node is None
+        ), (
+            node.output_score,
+            matched_node,
+        )
+        return (score + node.output_score, node, matched_node)

    def finalize(self, state: ContextState) -> Tuple[float, ContextState]:
        """When reaching the end of the decoded sequence, we need to finalize
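forward_one_step() now returns a third element, the matched end state (or None), and strict_mode decides whether matching keeps going after a phrase has fired. A hedged sketch in the spirit of the he/she/shell example from the docstring above, using character codes as token ids just like the _test() helper further down; the import path is an assumption:

    from icefall.context_graph import ContextGraph  # assumed import path

    phrases = ["HE", "SHE", "SHELL"]
    graph = ContextGraph(context_score=1.0)
    graph.build(
        token_ids=[[ord(c) for c in p] for p in phrases],
        phrases=phrases,
    )

    state = graph.root
    total = 0.0
    for ch in "LIKESHELL":
        # strict_mode=True (the default) keeps matching longer phrases after a
        # shorter one has already matched; False resets to the root instead.
        score, state, matched = graph.forward_one_step(state, ord(ch), strict_mode=True)
        total += score
        if matched is not None:
            # Only the longest phrase matched at this step is reported.
            print(f"matched {matched.phrase!r} (level {matched.level}) at {ch!r}")
    # Flush any partial match left at the end of the sequence.
    score, state = graph.finalize(state)
    total += score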
@@ -366,7 +462,7 @@ class ContextGraph:
        return dot


-def _test(queries, score):
+def _test(queries, score, strict_mode):
    contexts_str = [
        "S",
        "HE",
@@ -381,11 +477,15 @@ def _test(queries, score):

    # test default score (1)
    contexts = []
+    scores = []
+    phrases = []
    for s in contexts_str:
-        contexts.append(([ord(x) for x in s], score))
+        contexts.append([ord(x) for x in s])
+        scores.append(round(score / len(s), 2))
+        phrases.append(s)

    context_graph = ContextGraph(context_score=1)
-    context_graph.build(contexts)
+    context_graph.build(token_ids=contexts, scores=scores, phrases=phrases)

    symbol_table = {}
    for contexts in contexts_str:
@@ -402,7 +502,9 @@ def _test(queries, score):
        total_scores = 0
        state = context_graph.root
        for q in query:
-            score, state = context_graph.forward_one_step(state, ord(q))
+            score, state, phrase = context_graph.forward_one_step(
+                state, ord(q), strict_mode
+            )
            total_scores += score
        score, state = context_graph.finalize(state)
        assert state.token == -1, state.token
@@ -427,9 +529,22 @@ if __name__ == "__main__":
        "DHRHISQ": 4,  # "HIS", "S"
        "THEN": 2,  # "HE"
    }
-    _test(queries, 0)
+    _test(queries, 0, True)

-    # test custom score (5)
+    queries = {
+        "HEHERSHE": 7,  # "HE", "HE", "S", "HE"
+        "HERSHE": 5,  # "HE", "S", "HE"
+        "HISHE": 5,  # "HIS", "HE"
+        "SHED": 3,  # "S", "HE"
+        "SHELF": 3,  # "S", "HE"
+        "HELL": 2,  # "HE"
+        "HELLO": 2,  # "HE"
+        "DHRHISQ": 3,  # "HIS"
+        "THEN": 2,  # "HE"
+    }
+    _test(queries, 0, False)
+
+    # test custom score
    # S : 5
    # HE : 5 (2.5 + 2.5)
    # SHE : 8.34 (5 + 1.67 + 1.67)
@@ -450,4 +565,17 @@ if __name__ == "__main__":
        "THEN": 5,  # "HE"
    }

-    _test(queries, 5)
+    _test(queries, 5, True)
+
+    queries = {
+        "HEHERSHE": 20,  # "HE", "HE", "S", "HE"
+        "HERSHE": 15,  # "HE", "S", "HE"
+        "HISHE": 10.84,  # "HIS", "HE"
+        "SHED": 10,  # "S", "HE"
+        "SHELF": 10,  # "S", "HE"
+        "HELL": 5,  # "HE"
+        "HELLO": 5,  # "HE"
+        "DHRHISQ": 5.84,  # "HIS"
+        "THEN": 5,  # "HE"
+    }
+    _test(queries, 5, False)
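The expected totals in the two default-score query blocks above can be sanity-checked by hand, since each matched token contributes the default context_score of 1. A small check of the "DHRHISQ" row, which is where strict and non-strict matching differ (variable names are mine):

    # strict_mode=True finds both "HIS" and "S" in "DHRHISQ": 3 + 1 tokens = 4.
    # strict_mode=False keeps only the longest phrase matched at that step
    # ("HIS"), so the total is 3.
    strict_total = len("HIS") + len("S")
    non_strict_total = len("HIS")
    assert (strict_total, non_strict_total) == (4, 3)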