Fixes to forward_one_step; add draw to context graph

This commit is contained in:
pkufool 2023-05-11 14:23:48 +08:00
parent 62557a1564
commit 40a05810dd
3 changed files with 195 additions and 60 deletions

View File

@ -919,7 +919,7 @@ def main():
for line in open(params.context_file).readlines(): for line in open(params.context_file).readlines():
contexts.append(line.strip()) contexts.append(line.strip())
context_graph = ContextGraph(params.context_score) context_graph = ContextGraph(params.context_score)
context_graph.build_context_graph(sp.encode(contexts)) context_graph.build(sp.encode(contexts))
else: else:
context_graph = None context_graph = None
else: else:

View File

@ -855,7 +855,7 @@ def main():
for line in open(params.context_file).readlines(): for line in open(params.context_file).readlines():
contexts.append(graph_compiler.texts_to_ids(line.strip())) contexts.append(graph_compiler.texts_to_ids(line.strip()))
context_graph = ContextGraph(params.context_score) context_graph = ContextGraph(params.context_score)
context_graph.build_context_graph(contexts) context_graph.build(contexts)
else: else:
context_graph = None context_graph = None
else: else:

View File

@ -14,7 +14,10 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import Dict, List, Tuple import os
import shutil
from collections import deque
from typing import Dict, List, Optional, Tuple
class ContextState: class ContextState:
@ -22,28 +25,39 @@ class ContextState:
def __init__( def __init__(
self, self,
id: int,
token: int, token: int,
score: float, token_score: float,
total_score: float, node_score: float,
local_node_score: float,
is_end: bool, is_end: bool,
): ):
"""Create a ContextState. """Create a ContextState.
Args: Args:
id:
The node id, only for visualization now. A node is in [0, graph.num_nodes).
The id of the root node is always 0.
token: token:
The token id. The token id.
score: score:
The bonus for each token during decoding, which will hopefully The bonus for each token during decoding, which will hopefully
boost the token up to survive beam search. boost the token up to survive beam search.
total_score: node_score:
The accumulated bonus from root of graph to current node, it will be The accumulated bonus from root of graph to current node, it will be
used to calculate the score for fail arc. used to calculate the score for fail arc.
local_node_score:
The accumulated bonus from last ``end_node``(node with is_end true)
to current_node, it will be used to calculate the score for fail arc.
Node: The local_node_score of a ``end_node`` is 0.
is_end: is_end:
True if current token is the end of a context. True if current token is the end of a context.
""" """
self.id = id
self.token = token self.token = token
self.score = score self.token_score = token_score
self.total_score = total_score self.node_score = node_score
self.local_node_score = local_node_score
self.is_end = is_end self.is_end = is_end
self.next = {} self.next = {}
self.fail = None self.fail = None
@ -72,7 +86,15 @@ class ContextGraph:
word/phrase will have larger bonus score, they have to be matched though). word/phrase will have larger bonus score, they have to be matched though).
""" """
self.context_score = context_score self.context_score = context_score
self.root = ContextState(token=-1, score=0, total_score=0, is_end=False) self.num_nodes = 0
self.root = ContextState(
id=self.num_nodes,
token=-1,
token_score=0,
node_score=0,
local_node_score=0,
is_end=False,
)
self.root.fail = self.root self.root.fail = self.root
def _fill_fail(self): def _fill_fail(self):
@ -81,12 +103,12 @@ class ContextGraph:
See https://en.wikipedia.org/wiki/Aho%E2%80%93Corasick_algorithm for the See https://en.wikipedia.org/wiki/Aho%E2%80%93Corasick_algorithm for the
details of the algorithm. details of the algorithm.
""" """
queue = [] queue = deque()
for token, node in self.root.next.items(): for token, node in self.root.next.items():
node.fail = self.root node.fail = self.root
queue.append(node) queue.append(node)
while queue: while queue:
current_node = queue.pop(0) current_node = queue.popleft()
for token, node in current_node.next.items(): for token, node in current_node.next.items():
fail = current_node.fail fail = current_node.fail
if token in fail.next: if token in fail.next:
@ -102,7 +124,7 @@ class ContextGraph:
node.fail = fail node.fail = fail
queue.append(node) queue.append(node)
def build_context_graph(self, token_ids: List[List[int]]): def build(self, token_ids: List[List[int]]):
"""Build the ContextGraph from a list of token list. """Build the ContextGraph from a list of token list.
It first build a trie from the given token lists, then fill the fail arc It first build a trie from the given token lists, then fill the fail arc
for each trie node. for each trie node.
@ -120,13 +142,17 @@ class ContextGraph:
node = self.root node = self.root
for i, token in enumerate(tokens): for i, token in enumerate(tokens):
if token not in node.next: if token not in node.next:
self.num_nodes += 1
is_end = i == len(tokens) - 1
node.next[token] = ContextState( node.next[token] = ContextState(
id=self.num_nodes,
token=token, token=token,
score=self.context_score, token_score=self.context_score,
# The total score is the accumulated score from root to current node, node_score=node.node_score + self.context_score,
# it will be used to calculate the score of fail arc later. local_node_score=0
total_score=node.total_score + self.context_score, if is_end
is_end=i == len(tokens) - 1, else (node.local_node_score + self.context_score),
is_end=is_end,
) )
node = node.next[token] node = node.next[token]
self._fill_fail() self._fill_fail()
@ -138,7 +164,7 @@ class ContextGraph:
Args: Args:
state: state:
The given state (trie node) to start. The given token containing trie node to start.
token: token:
The given token. The given token.
@ -148,9 +174,7 @@ class ContextGraph:
# token matched # token matched
if token in state.next: if token in state.next:
node = state.next[token] node = state.next[token]
score = node.score score = node.token_score
if node.is_end:
node = self.root
return (score, node) return (score, node)
else: else:
# token not matched # token not matched
@ -164,10 +188,9 @@ class ContextGraph:
if token in node.next: if token in node.next:
node = node.next[token] node = node.next[token]
# The score of the fail arc
score = node.total_score - state.total_score # The score of the fail path
if node.is_end: score = node.node_score - state.local_node_score
node = self.root
return (score, node) return (score, node)
def finalize(self, state: ContextState) -> Tuple[float, ContextState]: def finalize(self, state: ContextState) -> Tuple[float, ContextState]:
@ -185,49 +208,161 @@ class ContextGraph:
to root. The next state is always root. to root. The next state is always root.
""" """
# The score of the fail arc # The score of the fail arc
score = self.root.total_score - state.total_score score = self.root.node_score - state.local_node_score
if state.is_end:
score = 0
return (score, self.root) return (score, self.root)
def draw(
self,
title: Optional[str] = None,
filename: Optional[str] = "",
symbol_table: Optional[Dict[int, str]] = None,
) -> "Digraph": # noqa
"""Visualize a ContextGraph via graphviz.
Render ContextGraph as an image via graphviz, and return the Digraph object;
and optionally save to file `filename`.
`filename` must have a suffix that graphviz understands, such as
`pdf`, `svg` or `png`.
Note:
You need to install graphviz to use this function::
pip install graphviz
Args:
title:
Title to be displayed in image, e.g. 'A simple FSA example'
filename:
Filename to (optionally) save to, e.g. 'foo.png', 'foo.svg',
'foo.png' (must have a suffix that graphviz understands).
symbol_table:
Map the token ids to symbols.
Returns:
A Diagraph from grahpviz.
"""
try:
import graphviz
except Exception:
print("You cannot use `to_dot` unless the graphviz package is installed.")
raise
graph_attr = {
"rankdir": "LR",
"size": "8.5,11",
"center": "1",
"orientation": "Portrait",
"ranksep": "0.4",
"nodesep": "0.25",
}
if title is not None:
graph_attr["label"] = title
default_node_attr = {
"shape": "circle",
"style": "bold",
"fontsize": "14",
}
final_state_attr = {
"shape": "doublecircle",
"style": "bold",
"fontsize": "14",
}
final_state = -1
dot = graphviz.Digraph(name="Context Graph", graph_attr=graph_attr)
seen = set()
queue = deque()
queue.append(self.root)
# root id is always 0
dot.node("0", label="0", **default_node_attr)
dot.edge("0", "0", label=f"*/0")
seen.add(0)
while len(queue):
current_node = queue.popleft()
for token, node in current_node.next.items():
if node.id not in seen:
node_score = f"{node.node_score:.2f}".rstrip("0").rstrip(".")
local_node_score = f"{node.local_node_score:.2f}".rstrip(
"0"
).rstrip(".")
label = f"{node.id}/({node_score},{local_node_score})"
if node.is_end:
dot.node(str(node.id), label=label, **final_state_attr)
else:
dot.node(str(node.id), label=label, **default_node_attr)
seen.add(node.id)
weight = f"{node.token_score:.2f}".rstrip("0").rstrip(".")
label = str(token) if symbol_table is None else symbol_table[token]
dot.edge(str(current_node.id), str(node.id), label=f"{label}/{weight}")
dot.edge(
str(node.id),
str(node.fail.id),
color="red",
)
queue.append(node)
if filename:
_, extension = os.path.splitext(filename)
if extension == "" or extension[0] != ".":
raise ValueError(
"Filename needs to have a suffix like .png, .pdf, .svg: {}".format(
filename
)
)
import tempfile
with tempfile.TemporaryDirectory() as tmp_dir:
temp_fn = dot.render(
filename="temp",
directory=tmp_dir,
format=extension[1:],
cleanup=True,
)
shutil.move(temp_fn, filename)
return dot
if __name__ == "__main__": if __name__ == "__main__":
contexts_str = ["HE", "SHE", "HIS", "HERS"] contexts_str = ["HE", "SHE", "SHELL", "HIS", "HERS", "HELLO"]
contexts = [] contexts = []
for s in contexts_str: for s in contexts_str:
contexts.append([ord(x) for x in s]) contexts.append([ord(x) for x in s])
context_graph = ContextGraph(context_score=2) context_graph = ContextGraph(context_score=1)
context_graph.build_context_graph(contexts) context_graph.build(contexts)
score, state = context_graph.forward_one_step(context_graph.root, ord("H")) symbol_table = {}
assert score == 2, score for contexts in contexts_str:
assert state.token == ord("H"), state.token for s in contexts:
symbol_table[ord(s)] = s
score, state = context_graph.forward_one_step(state, ord("I")) context_graph.draw(
assert score == 2, score title="Graph for: " + " / ".join(contexts_str),
assert state.token == ord("I"), state.token filename="context_graph.pdf",
symbol_table=symbol_table,
)
score, state = context_graph.forward_one_step(state, ord("S")) queries = ["HERSHE", "HISHE", "SHED", "HELL", "HELLO", "DHRHISQ"]
assert score == 2, score expected_scores = [7, 6, 3, 2, 5, 3]
assert state.token == -1, state.token for i, query in enumerate(queries):
total_scores = 0
score, state = context_graph.finalize(state) state = context_graph.root
assert score == 0, score for q in query:
assert state.token == -1, state.token score, state = context_graph.forward_one_step(state, ord(q))
total_scores += score
score, state = context_graph.forward_one_step(context_graph.root, ord("S")) score, state = context_graph.finalize(state)
assert score == 2, score assert state.token == -1, state.token
assert state.token == ord("S"), state.token total_scores += score
assert total_scores == expected_scores[i], (
score, state = context_graph.forward_one_step(state, ord("H")) total_scores,
assert score == 2, score expected_scores[i],
assert state.token == ord("H"), state.token query,
)
score, state = context_graph.forward_one_step(state, ord("D"))
assert score == -4, score
assert state.token == -1, state.token
score, state = context_graph.forward_one_step(context_graph.root, ord("D"))
assert score == 0, score
assert state.token == -1, state.token