From 17dab02dc98ad9361820a9ea956431b735c020db Mon Sep 17 00:00:00 2001 From: pkufool Date: Mon, 25 Dec 2023 12:18:32 +0800 Subject: [PATCH] various fixes to context graph to support kws system and bugs of hotwords --- icefall/context_graph.py | 200 ++++++++++++++++++++++++++++++++------- 1 file changed, 164 insertions(+), 36 deletions(-) diff --git a/icefall/context_graph.py b/icefall/context_graph.py index b3d7972a8..52a98f352 100644 --- a/icefall/context_graph.py +++ b/icefall/context_graph.py @@ -17,7 +17,7 @@ import os import shutil from collections import deque -from typing import Dict, List, Optional, Tuple +from typing import Dict, List, Optional, Tuple, Union class ContextState: @@ -31,6 +31,9 @@ class ContextState: node_score: float, output_score: float, is_end: bool, + level: int, + phrase: str = "", + ac_threshold: float = 1.0, ): """Create a ContextState. @@ -51,6 +54,15 @@ class ContextState: the output node for current node. is_end: True if current token is the end of a context. + level: + The distance from current node to root. + phrase: + The context phrase of current state, the value is valid only when + current state is end state (is_end == True). + ac_threshold: + The acoustic threshold (probability) of current context phrase, the + value is valid only when current state is end state (is_end == True). + Note: ac_threshold only used in keywords spotting. """ self.id = id self.token = token @@ -58,7 +70,10 @@ class ContextState: self.node_score = node_score self.output_score = output_score self.is_end = is_end + self.level = level self.next = {} + self.phrase = phrase + self.ac_threshold = ac_threshold self.fail = None self.output = None @@ -75,7 +90,7 @@ class ContextGraph: beam search. """ - def __init__(self, context_score: float): + def __init__(self, context_score: float, ac_threshold: float = 1.0): """Initialize a ContextGraph with the given ``context_score``. A root node will be created (**NOTE:** the token of root is hardcoded to -1). @@ -87,8 +102,12 @@ class ContextGraph: Note: This is just the default score for each token, the users can manually specify the context_score for each word/phrase (i.e. different phrase might have different token score). + ac_threshold: + The acoustic threshold (probability) to trigger the word/phrase, this argument + is used only when applying the graph to keywords spotting system. """ self.context_score = context_score + self.ac_threshold = ac_threshold self.num_nodes = 0 self.root = ContextState( id=self.num_nodes, @@ -97,6 +116,7 @@ class ContextGraph: node_score=0, output_score=0, is_end=False, + level=0, ) self.root.fail = self.root @@ -136,7 +156,13 @@ class ContextGraph: node.output_score += 0 if output is None else output.output_score queue.append(node) - def build(self, token_ids: List[Tuple[List[int], float]]): + def build( + self, + token_ids: List[List[int]], + phrases: Optional[List[str]] = None, + scores: Optional[List[float]] = None, + ac_thresholds: Optional[List[float]] = None, + ): """Build the ContextGraph from a list of token list. It first build a trie from the given token lists, then fill the fail arc for each trie node. @@ -145,52 +171,80 @@ class ContextGraph: Args: token_ids: - The given token lists to build the ContextGraph, it is a list of tuple of - token list and its customized score, the token list contains the token ids + The given token lists to build the ContextGraph, it is a list of + token list, the token list contains the token ids for a word/phrase. The token id could be an id of a char (modeling with single Chinese char) or an id of a BPE - (modeling with BPEs). The score is the total score for current token list, + (modeling with BPEs). + phrases: + The given phrases, they are the original text of the token_ids, the + length of `phrases` MUST be equal to the length of `token_ids`. + scores: + The customize boosting score(token level) for each word/phrase, 0 means using the default value (i.e. self.context_score). + It is a list of floats, and the length of `scores` MUST be equal to + the length of `token_ids`. + ac_thresholds: + The customize trigger acoustic threshold (probability) for each phrase, + 0 means using the default value (i.e. self.ac_threshold). It is + used only when this graph applied for the keywords spotting system. + The length of `ac_threshold` MUST be equal to the length of `token_ids`. Note: The phrases would have shared states, the score of the shared states is - the maximum value among all the tokens sharing this state. + the MAXIMUM value among all the tokens sharing this state. """ - for (tokens, score) in token_ids: + num_phrases = len(token_ids) + if phrases is not None: + assert len(phrases) == num_phrases, (len(phrases), num_phrases) + if scores is not None: + assert len(scores) == num_phrases, (len(scores), num_phrases) + if ac_thresholds is not None: + assert len(ac_thresholds) == num_phrases, (len(ac_thresholds), num_phrases) + + for index, tokens in enumerate(token_ids): + phrase = "" if phrases is None else phrases[index] + score = 0.0 if scores is None else scores[index] + ac_threshold = 0.0 if ac_thresholds is None else ac_thresholds[index] node = self.root # If has customized score using the customized token score, otherwise # using the default score - context_score = ( - self.context_score if score == 0.0 else round(score / len(tokens), 2) - ) + context_score = self.context_score if score == 0.0 else score + threshold = self.ac_threshold if ac_threshold == 0.0 else ac_threshold for i, token in enumerate(tokens): node_next = {} if token not in node.next: self.num_nodes += 1 - node_id = self.num_nodes - token_score = context_score is_end = i == len(tokens) - 1 + node_score = node.node_score + context_score + node.next[token] = ContextState( + id=self.num_nodes, + token=token, + token_score=context_score, + node_score=node_score, + output_score=node_score if is_end else 0, + is_end=is_end, + level=i + 1, + phrase=phrase if is_end else "", + ac_threshold=threshold if is_end else 0.0, + ) else: # node exists, get the score of shared state. token_score = max(context_score, node.next[token].token_score) - node_id = node.next[token].id - node_next = node.next[token].next + node.next[token].token_score = token_score + node_score = node.node_score + token_score + node.next[token].node_score = node_score is_end = i == len(tokens) - 1 or node.next[token].is_end - node_score = node.node_score + token_score - node.next[token] = ContextState( - id=node_id, - token=token, - token_score=token_score, - node_score=node_score, - output_score=node_score if is_end else 0, - is_end=is_end, - ) - node.next[token].next = node_next + node.next[token].output_score = node_score if is_end else 0 + node.next[token].is_end = is_end + if i == len(tokens) - 1: + node.next[token].phrase = phrase + node.next[token].ac_threshold = threshold node = node.next[token] self._fill_fail_output() def forward_one_step( - self, state: ContextState, token: int - ) -> Tuple[float, ContextState]: + self, state: ContextState, token: int, strict_mode: bool = True + ) -> Tuple[float, ContextState, ContextState]: """Search the graph with given state and token. Args: @@ -198,9 +252,27 @@ class ContextGraph: The given token containing trie node to start. token: The given token. + strict_mode: + If the `strict_mode` is True, it can match multiple phrases simultaneously, + and will continue to match longer phrase after matching a shorter one. + If the `strict_mode` is False, it can only match one phrase at a time, + when it matches a phrase, then the state will fall back to root state + (i.e. forgetting all the history state and starting a new match). If + the matched state have multiple outputs (node.output is not None), the + longest phrase will be return. + For example, if the phrases are `he`, `she` and `shell`, the query is + `like shell`, when `strict_mode` is True, the query will match `he` and + `she` at token `e` and `shell` at token `l`, while when `strict_mode` + if False, the query can only match `she`(`she` is longer than `he`, so + `she` not `he`) at token `e`. + Caution: When applying this graph for keywords spotting system, the + `strict_mode` MUST be True. Returns: - Return a tuple of score and next state. + Return a tuple of boosting score for current state, next state and matched + state (if any). Note: Only returns the matched state with longest phrase of + current state, even if there are multiple matches phrases. If no phrase + matched, the matched state is None. """ node = None score = 0 @@ -224,7 +296,31 @@ class ContextGraph: # The score of the fail path score = node.node_score - state.node_score assert node is not None - return (score + node.output_score, node) + + # The matched node of current step, will only return the node with + # longest phrase if there are multiple phrases matches this step. + # None if no matched phrase. + matched_node = ( + node if node.is_end else (node.output if node.output is not None else None) + ) + if not strict_mode and node.output_score != 0: + # output_score != 0 means at least on phrase matched + assert matched_node is not None + output_score = ( + node.node_score + if node.is_end + else ( + node.node_score if node.output is None else node.output.node_score + ) + ) + return (score + output_score - node.node_score, self.root, matched_node) + assert (node.output_score != 0 and matched_node is not None) or ( + node.output_score == 0 and matched_node is None + ), ( + node.output_score, + matched_node, + ) + return (score + node.output_score, node, matched_node) def finalize(self, state: ContextState) -> Tuple[float, ContextState]: """When reaching the end of the decoded sequence, we need to finalize @@ -366,7 +462,7 @@ class ContextGraph: return dot -def _test(queries, score): +def _test(queries, score, strict_mode): contexts_str = [ "S", "HE", @@ -381,11 +477,15 @@ def _test(queries, score): # test default score (1) contexts = [] + scores = [] + phrases = [] for s in contexts_str: - contexts.append(([ord(x) for x in s], score)) + contexts.append([ord(x) for x in s]) + scores.append(round(score / len(s), 2)) + phrases.append(s) context_graph = ContextGraph(context_score=1) - context_graph.build(contexts) + context_graph.build(token_ids=contexts, scores=scores, phrases=phrases) symbol_table = {} for contexts in contexts_str: @@ -402,7 +502,9 @@ def _test(queries, score): total_scores = 0 state = context_graph.root for q in query: - score, state = context_graph.forward_one_step(state, ord(q)) + score, state, phrase = context_graph.forward_one_step( + state, ord(q), strict_mode + ) total_scores += score score, state = context_graph.finalize(state) assert state.token == -1, state.token @@ -427,9 +529,22 @@ if __name__ == "__main__": "DHRHISQ": 4, # "HIS", "S" "THEN": 2, # "HE" } - _test(queries, 0) + _test(queries, 0, True) - # test custom score (5) + queries = { + "HEHERSHE": 7, # "HE", "HE", "S", "HE" + "HERSHE": 5, # "HE", "S", "HE" + "HISHE": 5, # "HIS", "HE" + "SHED": 3, # "S", "HE" + "SHELF": 3, # "S", "HE" + "HELL": 2, # "HE" + "HELLO": 2, # "HE" + "DHRHISQ": 3, # "HIS" + "THEN": 2, # "HE" + } + _test(queries, 0, False) + + # test custom score # S : 5 # HE : 5 (2.5 + 2.5) # SHE : 8.34 (5 + 1.67 + 1.67) @@ -450,4 +565,17 @@ if __name__ == "__main__": "THEN": 5, # "HE" } - _test(queries, 5) + _test(queries, 5, True) + + queries = { + "HEHERSHE": 20, # "HE", "HE", "S", "HE" + "HERSHE": 15, # "HE", "S", "HE" + "HISHE": 10.84, # "HIS", "HE" + "SHED": 10, # "S", "HE" + "SHELF": 10, # "S", "HE" + "HELL": 5, # "HE" + "HELLO": 5, # "HE" + "DHRHISQ": 5.84, # "HIS" + "THEN": 5, # "HE" + } + _test(queries, 5, False)