From 40a05810ddd6b687ee8d9015b5837c316ed239bd Mon Sep 17 00:00:00 2001 From: pkufool Date: Thu, 11 May 2023 14:23:48 +0800 Subject: [PATCH] Fixes to forward_one_step; add draw to context graph --- .../pruned_transducer_stateless4/decode.py | 2 +- .../pruned_transducer_stateless5/decode.py | 2 +- icefall/context_graph.py | 251 ++++++++++++++---- 3 files changed, 195 insertions(+), 60 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py index 6a8e82fff..6aacd7f92 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py @@ -919,7 +919,7 @@ def main(): for line in open(params.context_file).readlines(): contexts.append(line.strip()) context_graph = ContextGraph(params.context_score) - context_graph.build_context_graph(sp.encode(contexts)) + context_graph.build(sp.encode(contexts)) else: context_graph = None else: diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode.py index 799992ae6..e2d5eae18 100755 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode.py +++ b/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode.py @@ -855,7 +855,7 @@ def main(): for line in open(params.context_file).readlines(): contexts.append(graph_compiler.texts_to_ids(line.strip())) context_graph = ContextGraph(params.context_score) - context_graph.build_context_graph(contexts) + context_graph.build(contexts) else: context_graph = None else: diff --git a/icefall/context_graph.py b/icefall/context_graph.py index 1936f43b5..61eb5090c 100644 --- a/icefall/context_graph.py +++ b/icefall/context_graph.py @@ -14,7 +14,10 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Dict, List, Tuple +import os +import shutil +from collections import deque +from typing import Dict, List, Optional, Tuple class ContextState: @@ -22,28 +25,39 @@ class ContextState: def __init__( self, + id: int, token: int, - score: float, - total_score: float, + token_score: float, + node_score: float, + local_node_score: float, is_end: bool, ): """Create a ContextState. Args: + id: + The node id, only for visualization now. A node is in [0, graph.num_nodes). + The id of the root node is always 0. token: The token id. score: The bonus for each token during decoding, which will hopefully boost the token up to survive beam search. - total_score: + node_score: The accumulated bonus from root of graph to current node, it will be used to calculate the score for fail arc. + local_node_score: + The accumulated bonus from last ``end_node``(node with is_end true) + to current_node, it will be used to calculate the score for fail arc. + Note: The local_node_score of an ``end_node`` is 0. is_end: True if current token is the end of a context. """ + self.id = id self.token = token - self.score = score - self.total_score = total_score + self.token_score = token_score + self.node_score = node_score + self.local_node_score = local_node_score self.is_end = is_end self.next = {} self.fail = None @@ -72,7 +86,15 @@ class ContextGraph: word/phrase will have larger bonus score, they have to be matched though). """ self.context_score = context_score - self.root = ContextState(token=-1, score=0, total_score=0, is_end=False) + self.num_nodes = 0 + self.root = ContextState( + id=self.num_nodes, + token=-1, + token_score=0, + node_score=0, + local_node_score=0, + is_end=False, + ) self.root.fail = self.root def _fill_fail(self): @@ -81,12 +103,12 @@ class ContextGraph: See https://en.wikipedia.org/wiki/Aho%E2%80%93Corasick_algorithm for the details of the algorithm. 
""" - queue = [] + queue = deque() for token, node in self.root.next.items(): node.fail = self.root queue.append(node) while queue: - current_node = queue.pop(0) + current_node = queue.popleft() for token, node in current_node.next.items(): fail = current_node.fail if token in fail.next: @@ -102,7 +124,7 @@ class ContextGraph: node.fail = fail queue.append(node) - def build_context_graph(self, token_ids: List[List[int]]): + def build(self, token_ids: List[List[int]]): """Build the ContextGraph from a list of token list. It first build a trie from the given token lists, then fill the fail arc for each trie node. @@ -120,13 +142,17 @@ class ContextGraph: node = self.root for i, token in enumerate(tokens): if token not in node.next: + self.num_nodes += 1 + is_end = i == len(tokens) - 1 node.next[token] = ContextState( + id=self.num_nodes, token=token, - score=self.context_score, - # The total score is the accumulated score from root to current node, - # it will be used to calculate the score of fail arc later. - total_score=node.total_score + self.context_score, - is_end=i == len(tokens) - 1, + token_score=self.context_score, + node_score=node.node_score + self.context_score, + local_node_score=0 + if is_end + else (node.local_node_score + self.context_score), + is_end=is_end, ) node = node.next[token] self._fill_fail() @@ -138,7 +164,7 @@ class ContextGraph: Args: state: - The given state (trie node) to start. + The given token containing trie node to start. token: The given token. 
@@ -148,9 +174,7 @@ class ContextGraph: # token matched if token in state.next: node = state.next[token] - score = node.score - if node.is_end: - node = self.root + score = node.token_score return (score, node) else: # token not matched @@ -164,10 +188,9 @@ class ContextGraph: if token in node.next: node = node.next[token] - # The score of the fail arc - score = node.total_score - state.total_score - if node.is_end: - node = self.root + + # The score of the fail path + score = node.node_score - state.local_node_score return (score, node) def finalize(self, state: ContextState) -> Tuple[float, ContextState]: @@ -185,49 +208,161 @@ class ContextGraph: to root. The next state is always root. """ # The score of the fail arc - score = self.root.total_score - state.total_score - if state.is_end: - score = 0 + score = self.root.node_score - state.local_node_score return (score, self.root) + def draw( + self, + title: Optional[str] = None, + filename: Optional[str] = "", + symbol_table: Optional[Dict[int, str]] = None, + ) -> "Digraph": # noqa + + """Visualize a ContextGraph via graphviz. + + Render ContextGraph as an image via graphviz, and return the Digraph object; + and optionally save to file `filename`. + `filename` must have a suffix that graphviz understands, such as + `pdf`, `svg` or `png`. + + Note: + You need to install graphviz to use this function:: + + pip install graphviz + + Args: + title: + Title to be displayed in image, e.g. 'A simple FSA example' + filename: + Filename to (optionally) save to, e.g. 'foo.png', 'foo.svg', + 'foo.pdf' (must have a suffix that graphviz understands). + symbol_table: + Map the token ids to symbols. + Returns: + A Digraph from graphviz. 
+ """ + + try: + import graphviz + except Exception: + print("You cannot use `to_dot` unless the graphviz package is installed.") + raise + + graph_attr = { + "rankdir": "LR", + "size": "8.5,11", + "center": "1", + "orientation": "Portrait", + "ranksep": "0.4", + "nodesep": "0.25", + } + if title is not None: + graph_attr["label"] = title + + default_node_attr = { + "shape": "circle", + "style": "bold", + "fontsize": "14", + } + + final_state_attr = { + "shape": "doublecircle", + "style": "bold", + "fontsize": "14", + } + + final_state = -1 + dot = graphviz.Digraph(name="Context Graph", graph_attr=graph_attr) + + seen = set() + queue = deque() + queue.append(self.root) + # root id is always 0 + dot.node("0", label="0", **default_node_attr) + dot.edge("0", "0", label=f"*/0") + seen.add(0) + + while len(queue): + current_node = queue.popleft() + for token, node in current_node.next.items(): + if node.id not in seen: + node_score = f"{node.node_score:.2f}".rstrip("0").rstrip(".") + local_node_score = f"{node.local_node_score:.2f}".rstrip( + "0" + ).rstrip(".") + label = f"{node.id}/({node_score},{local_node_score})" + if node.is_end: + dot.node(str(node.id), label=label, **final_state_attr) + else: + dot.node(str(node.id), label=label, **default_node_attr) + seen.add(node.id) + weight = f"{node.token_score:.2f}".rstrip("0").rstrip(".") + label = str(token) if symbol_table is None else symbol_table[token] + dot.edge(str(current_node.id), str(node.id), label=f"{label}/{weight}") + dot.edge( + str(node.id), + str(node.fail.id), + color="red", + ) + queue.append(node) + + if filename: + _, extension = os.path.splitext(filename) + if extension == "" or extension[0] != ".": + raise ValueError( + "Filename needs to have a suffix like .png, .pdf, .svg: {}".format( + filename + ) + ) + + import tempfile + + with tempfile.TemporaryDirectory() as tmp_dir: + temp_fn = dot.render( + filename="temp", + directory=tmp_dir, + format=extension[1:], + cleanup=True, + ) + + 
shutil.move(temp_fn, filename) + + return dot + if __name__ == "__main__": - contexts_str = ["HE", "SHE", "HIS", "HERS"] + contexts_str = ["HE", "SHE", "SHELL", "HIS", "HERS", "HELLO"] contexts = [] for s in contexts_str: contexts.append([ord(x) for x in s]) - context_graph = ContextGraph(context_score=2) - context_graph.build_context_graph(contexts) + context_graph = ContextGraph(context_score=1) + context_graph.build(contexts) - score, state = context_graph.forward_one_step(context_graph.root, ord("H")) - assert score == 2, score - assert state.token == ord("H"), state.token + symbol_table = {} + for contexts in contexts_str: + for s in contexts: + symbol_table[ord(s)] = s - score, state = context_graph.forward_one_step(state, ord("I")) - assert score == 2, score - assert state.token == ord("I"), state.token + context_graph.draw( + title="Graph for: " + " / ".join(contexts_str), + filename="context_graph.pdf", + symbol_table=symbol_table, + ) - score, state = context_graph.forward_one_step(state, ord("S")) - assert score == 2, score - assert state.token == -1, state.token - - score, state = context_graph.finalize(state) - assert score == 0, score - assert state.token == -1, state.token - - score, state = context_graph.forward_one_step(context_graph.root, ord("S")) - assert score == 2, score - assert state.token == ord("S"), state.token - - score, state = context_graph.forward_one_step(state, ord("H")) - assert score == 2, score - assert state.token == ord("H"), state.token - - score, state = context_graph.forward_one_step(state, ord("D")) - assert score == -4, score - assert state.token == -1, state.token - - score, state = context_graph.forward_one_step(context_graph.root, ord("D")) - assert score == 0, score - assert state.token == -1, state.token + queries = ["HERSHE", "HISHE", "SHED", "HELL", "HELLO", "DHRHISQ"] + expected_scores = [7, 6, 3, 2, 5, 3] + for i, query in enumerate(queries): + total_scores = 0 + state = context_graph.root + for q in query: + 
score, state = context_graph.forward_one_step(state, ord(q)) + total_scores += score + score, state = context_graph.finalize(state) + assert state.token == -1, state.token + total_scores += score + assert total_scores == expected_scores[i], ( + total_scores, + expected_scores[i], + query, + )