Fixes to forward_one_step; add draw to context graph

This commit is contained in:
pkufool 2023-05-11 14:23:48 +08:00
parent 62557a1564
commit 40a05810dd
3 changed files with 195 additions and 60 deletions

View File

@ -919,7 +919,7 @@ def main():
for line in open(params.context_file).readlines():
contexts.append(line.strip())
context_graph = ContextGraph(params.context_score)
context_graph.build_context_graph(sp.encode(contexts))
context_graph.build(sp.encode(contexts))
else:
context_graph = None
else:

View File

@ -855,7 +855,7 @@ def main():
for line in open(params.context_file).readlines():
contexts.append(graph_compiler.texts_to_ids(line.strip()))
context_graph = ContextGraph(params.context_score)
context_graph.build_context_graph(contexts)
context_graph.build(contexts)
else:
context_graph = None
else:

View File

@ -14,7 +14,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict, List, Tuple
import os
import shutil
from collections import deque
from typing import Dict, List, Optional, Tuple
class ContextState:
@ -22,28 +25,39 @@ class ContextState:
def __init__(
self,
id: int,
token: int,
score: float,
total_score: float,
token_score: float,
node_score: float,
local_node_score: float,
is_end: bool,
):
"""Create a ContextState.
Args:
id:
The node id, only for visualization now. A node is in [0, graph.num_nodes).
The id of the root node is always 0.
token:
The token id.
score:
The bonus for each token during decoding, which will hopefully
boost the token up to survive beam search.
total_score:
node_score:
The accumulated bonus from root of graph to current node, it will be
used to calculate the score for fail arc.
local_node_score:
The accumulated bonus from last ``end_node``(node with is_end true)
to current_node, it will be used to calculate the score for fail arc.
Node: The local_node_score of a ``end_node`` is 0.
is_end:
True if current token is the end of a context.
"""
self.id = id
self.token = token
self.score = score
self.total_score = total_score
self.token_score = token_score
self.node_score = node_score
self.local_node_score = local_node_score
self.is_end = is_end
self.next = {}
self.fail = None
@ -72,7 +86,15 @@ class ContextGraph:
word/phrase will have larger bonus score, they have to be matched though).
"""
self.context_score = context_score
self.root = ContextState(token=-1, score=0, total_score=0, is_end=False)
self.num_nodes = 0
self.root = ContextState(
id=self.num_nodes,
token=-1,
token_score=0,
node_score=0,
local_node_score=0,
is_end=False,
)
self.root.fail = self.root
def _fill_fail(self):
@ -81,12 +103,12 @@ class ContextGraph:
See https://en.wikipedia.org/wiki/Aho%E2%80%93Corasick_algorithm for the
details of the algorithm.
"""
queue = []
queue = deque()
for token, node in self.root.next.items():
node.fail = self.root
queue.append(node)
while queue:
current_node = queue.pop(0)
current_node = queue.popleft()
for token, node in current_node.next.items():
fail = current_node.fail
if token in fail.next:
@ -102,7 +124,7 @@ class ContextGraph:
node.fail = fail
queue.append(node)
def build_context_graph(self, token_ids: List[List[int]]):
def build(self, token_ids: List[List[int]]):
"""Build the ContextGraph from a list of token list.
It first build a trie from the given token lists, then fill the fail arc
for each trie node.
@ -120,13 +142,17 @@ class ContextGraph:
node = self.root
for i, token in enumerate(tokens):
if token not in node.next:
self.num_nodes += 1
is_end = i == len(tokens) - 1
node.next[token] = ContextState(
id=self.num_nodes,
token=token,
score=self.context_score,
# The total score is the accumulated score from root to current node,
# it will be used to calculate the score of fail arc later.
total_score=node.total_score + self.context_score,
is_end=i == len(tokens) - 1,
token_score=self.context_score,
node_score=node.node_score + self.context_score,
local_node_score=0
if is_end
else (node.local_node_score + self.context_score),
is_end=is_end,
)
node = node.next[token]
self._fill_fail()
@ -138,7 +164,7 @@ class ContextGraph:
Args:
state:
The given state (trie node) to start.
The given token containing trie node to start.
token:
The given token.
@ -148,9 +174,7 @@ class ContextGraph:
# token matched
if token in state.next:
node = state.next[token]
score = node.score
if node.is_end:
node = self.root
score = node.token_score
return (score, node)
else:
# token not matched
@ -164,10 +188,9 @@ class ContextGraph:
if token in node.next:
node = node.next[token]
# The score of the fail arc
score = node.total_score - state.total_score
if node.is_end:
node = self.root
# The score of the fail path
score = node.node_score - state.local_node_score
return (score, node)
def finalize(self, state: ContextState) -> Tuple[float, ContextState]:
@ -185,49 +208,161 @@ class ContextGraph:
to root. The next state is always root.
"""
# The score of the fail arc
score = self.root.total_score - state.total_score
if state.is_end:
score = 0
score = self.root.node_score - state.local_node_score
return (score, self.root)
def draw(
self,
title: Optional[str] = None,
filename: Optional[str] = "",
symbol_table: Optional[Dict[int, str]] = None,
) -> "Digraph": # noqa
"""Visualize a ContextGraph via graphviz.
Render ContextGraph as an image via graphviz, and return the Digraph object;
and optionally save to file `filename`.
`filename` must have a suffix that graphviz understands, such as
`pdf`, `svg` or `png`.
Note:
You need to install graphviz to use this function::
pip install graphviz
Args:
title:
Title to be displayed in image, e.g. 'A simple FSA example'
filename:
Filename to (optionally) save to, e.g. 'foo.png', 'foo.svg',
'foo.png' (must have a suffix that graphviz understands).
symbol_table:
Map the token ids to symbols.
Returns:
A Diagraph from grahpviz.
"""
try:
import graphviz
except Exception:
print("You cannot use `to_dot` unless the graphviz package is installed.")
raise
graph_attr = {
"rankdir": "LR",
"size": "8.5,11",
"center": "1",
"orientation": "Portrait",
"ranksep": "0.4",
"nodesep": "0.25",
}
if title is not None:
graph_attr["label"] = title
default_node_attr = {
"shape": "circle",
"style": "bold",
"fontsize": "14",
}
final_state_attr = {
"shape": "doublecircle",
"style": "bold",
"fontsize": "14",
}
final_state = -1
dot = graphviz.Digraph(name="Context Graph", graph_attr=graph_attr)
seen = set()
queue = deque()
queue.append(self.root)
# root id is always 0
dot.node("0", label="0", **default_node_attr)
dot.edge("0", "0", label=f"*/0")
seen.add(0)
while len(queue):
current_node = queue.popleft()
for token, node in current_node.next.items():
if node.id not in seen:
node_score = f"{node.node_score:.2f}".rstrip("0").rstrip(".")
local_node_score = f"{node.local_node_score:.2f}".rstrip(
"0"
).rstrip(".")
label = f"{node.id}/({node_score},{local_node_score})"
if node.is_end:
dot.node(str(node.id), label=label, **final_state_attr)
else:
dot.node(str(node.id), label=label, **default_node_attr)
seen.add(node.id)
weight = f"{node.token_score:.2f}".rstrip("0").rstrip(".")
label = str(token) if symbol_table is None else symbol_table[token]
dot.edge(str(current_node.id), str(node.id), label=f"{label}/{weight}")
dot.edge(
str(node.id),
str(node.fail.id),
color="red",
)
queue.append(node)
if filename:
_, extension = os.path.splitext(filename)
if extension == "" or extension[0] != ".":
raise ValueError(
"Filename needs to have a suffix like .png, .pdf, .svg: {}".format(
filename
)
)
import tempfile
with tempfile.TemporaryDirectory() as tmp_dir:
temp_fn = dot.render(
filename="temp",
directory=tmp_dir,
format=extension[1:],
cleanup=True,
)
shutil.move(temp_fn, filename)
return dot
if __name__ == "__main__":
contexts_str = ["HE", "SHE", "HIS", "HERS"]
contexts_str = ["HE", "SHE", "SHELL", "HIS", "HERS", "HELLO"]
contexts = []
for s in contexts_str:
contexts.append([ord(x) for x in s])
context_graph = ContextGraph(context_score=2)
context_graph.build_context_graph(contexts)
context_graph = ContextGraph(context_score=1)
context_graph.build(contexts)
score, state = context_graph.forward_one_step(context_graph.root, ord("H"))
assert score == 2, score
assert state.token == ord("H"), state.token
symbol_table = {}
for contexts in contexts_str:
for s in contexts:
symbol_table[ord(s)] = s
score, state = context_graph.forward_one_step(state, ord("I"))
assert score == 2, score
assert state.token == ord("I"), state.token
context_graph.draw(
title="Graph for: " + " / ".join(contexts_str),
filename="context_graph.pdf",
symbol_table=symbol_table,
)
score, state = context_graph.forward_one_step(state, ord("S"))
assert score == 2, score
assert state.token == -1, state.token
score, state = context_graph.finalize(state)
assert score == 0, score
assert state.token == -1, state.token
score, state = context_graph.forward_one_step(context_graph.root, ord("S"))
assert score == 2, score
assert state.token == ord("S"), state.token
score, state = context_graph.forward_one_step(state, ord("H"))
assert score == 2, score
assert state.token == ord("H"), state.token
score, state = context_graph.forward_one_step(state, ord("D"))
assert score == -4, score
assert state.token == -1, state.token
score, state = context_graph.forward_one_step(context_graph.root, ord("D"))
assert score == 0, score
assert state.token == -1, state.token
queries = ["HERSHE", "HISHE", "SHED", "HELL", "HELLO", "DHRHISQ"]
expected_scores = [7, 6, 3, 2, 5, 3]
for i, query in enumerate(queries):
total_scores = 0
state = context_graph.root
for q in query:
score, state = context_graph.forward_one_step(state, ord(q))
total_scores += score
score, state = context_graph.finalize(state)
assert state.token == -1, state.token
total_scores += score
assert total_scores == expected_scores[i], (
total_scores,
expected_scores[i],
query,
)