fix context graph

This commit is contained in:
pkufool 2023-08-28 15:52:46 +08:00
parent e90563cdff
commit 306380f839

View File

@ -29,7 +29,7 @@ class ContextState:
token: int,
token_score: float,
node_score: float,
local_node_score: float,
output_score: float,
is_end: bool,
):
"""Create a ContextState.
@ -40,16 +40,15 @@ class ContextState:
The id of the root node is always 0.
token:
The token id.
score:
token_score:
The bonus for each token during decoding, which will hopefully
boost the token up to survive beam search.
node_score:
The accumulated bonus from root of graph to current node, it will be
used to calculate the score for fail arc.
local_node_score:
The accumulated bonus from last ``end_node``(node with is_end true)
to current_node, it will be used to calculate the score for fail arc.
Node: The local_node_score of a ``end_node`` is 0.
output_score:
The total scores of matched phrases, sum of the node_score of all
the output node for current node.
is_end:
True if current token is the end of a context.
"""
@ -57,7 +56,7 @@ class ContextState:
self.token = token
self.token_score = token_score
self.node_score = node_score
self.local_node_score = local_node_score
self.output_score = output_score
self.is_end = is_end
self.next = {}
self.fail = None
@ -93,7 +92,7 @@ class ContextGraph:
token=-1,
token_score=0,
node_score=0,
local_node_score=0,
output_score=0,
is_end=False,
)
self.root.fail = self.root
@ -131,6 +130,7 @@ class ContextGraph:
output = None
break
node.output = output
node.output_score += 0 if output is None else output.output_score
queue.append(node)
def build(self, token_ids: List[List[int]]):
@ -153,14 +153,13 @@ class ContextGraph:
if token not in node.next:
self.num_nodes += 1
is_end = i == len(tokens) - 1
node_score = node.node_score + self.context_score
node.next[token] = ContextState(
id=self.num_nodes,
token=token,
token_score=self.context_score,
node_score=node.node_score + self.context_score,
local_node_score=0
if is_end
else (node.local_node_score + self.context_score),
node_score=node_score,
output_score=node_score if is_end else 0,
is_end=is_end,
)
node = node.next[token]
@ -186,8 +185,6 @@ class ContextGraph:
if token in state.next:
node = state.next[token]
score = node.token_score
if state.is_end:
score += state.node_score
else:
# token not matched
# We will trace along the fail arc until it matches the token or reaching
@ -202,14 +199,9 @@ class ContextGraph:
node = node.next[token]
# The score of the fail path
score = node.node_score - state.local_node_score
score = node.node_score - state.node_score
assert node is not None
matched_score = 0
output = node.output
while output is not None:
matched_score += output.node_score
output = output.output
return (score + matched_score, node)
return (score + node.output_score, node)
def finalize(self, state: ContextState) -> Tuple[float, ContextState]:
"""When reaching the end of the decoded sequence, we need to finalize
@ -227,8 +219,6 @@ class ContextGraph:
"""
# The score of the fail arc
score = -state.node_score
if state.is_end:
score = 0
return (score, self.root)
def draw(
@ -307,10 +297,8 @@ class ContextGraph:
for token, node in current_node.next.items():
if node.id not in seen:
node_score = f"{node.node_score:.2f}".rstrip("0").rstrip(".")
local_node_score = f"{node.local_node_score:.2f}".rstrip(
"0"
).rstrip(".")
label = f"{node.id}/({node_score},{local_node_score})"
output_score = f"{node.output_score:.2f}".rstrip("0").rstrip(".")
label = f"{node.id}/({node_score}, {output_score})"
if node.is_end:
dot.node(str(node.id), label=label, **final_state_attr)
else:
@ -391,6 +379,7 @@ if __name__ == "__main__":
"HERSHE": 12, # "HE", "HERS", "S", "SHE", "HE"
"HISHE": 9, # "HIS", "S", "SHE", "HE"
"SHED": 6, # "S", "SHE", "HE"
"SHELF": 6, # "S", "SHE", "HE"
"HELL": 2, # "HE"
"HELLO": 7, # "HE", "HELLO"
"DHRHISQ": 4, # "HIS", "S"