fix context graph

This commit is contained in:
pkufool 2023-08-28 15:52:46 +08:00
parent e90563cdff
commit 306380f839

View File

@ -29,7 +29,7 @@ class ContextState:
token: int, token: int,
token_score: float, token_score: float,
node_score: float, node_score: float,
local_node_score: float, output_score: float,
is_end: bool, is_end: bool,
): ):
"""Create a ContextState. """Create a ContextState.
@ -40,16 +40,15 @@ class ContextState:
The id of the root node is always 0. The id of the root node is always 0.
token: token:
The token id. The token id.
score: token_score:
The bonus for each token during decoding, which will hopefully The bonus for each token during decoding, which will hopefully
boost the token up to survive beam search. boost the token up to survive beam search.
node_score: node_score:
The accumulated bonus from root of graph to current node, it will be The accumulated bonus from root of graph to current node, it will be
used to calculate the score for fail arc. used to calculate the score for fail arc.
local_node_score: output_score:
The accumulated bonus from last ``end_node``(node with is_end true) The total scores of matched phrases, sum of the node_score of all
to current_node, it will be used to calculate the score for fail arc. the output node for current node.
Node: The local_node_score of a ``end_node`` is 0.
is_end: is_end:
True if current token is the end of a context. True if current token is the end of a context.
""" """
@ -57,7 +56,7 @@ class ContextState:
self.token = token self.token = token
self.token_score = token_score self.token_score = token_score
self.node_score = node_score self.node_score = node_score
self.local_node_score = local_node_score self.output_score = output_score
self.is_end = is_end self.is_end = is_end
self.next = {} self.next = {}
self.fail = None self.fail = None
@ -93,7 +92,7 @@ class ContextGraph:
token=-1, token=-1,
token_score=0, token_score=0,
node_score=0, node_score=0,
local_node_score=0, output_score=0,
is_end=False, is_end=False,
) )
self.root.fail = self.root self.root.fail = self.root
@ -131,6 +130,7 @@ class ContextGraph:
output = None output = None
break break
node.output = output node.output = output
node.output_score += 0 if output is None else output.output_score
queue.append(node) queue.append(node)
def build(self, token_ids: List[List[int]]): def build(self, token_ids: List[List[int]]):
@ -153,14 +153,13 @@ class ContextGraph:
if token not in node.next: if token not in node.next:
self.num_nodes += 1 self.num_nodes += 1
is_end = i == len(tokens) - 1 is_end = i == len(tokens) - 1
node_score = node.node_score + self.context_score
node.next[token] = ContextState( node.next[token] = ContextState(
id=self.num_nodes, id=self.num_nodes,
token=token, token=token,
token_score=self.context_score, token_score=self.context_score,
node_score=node.node_score + self.context_score, node_score=node_score,
local_node_score=0 output_score=node_score if is_end else 0,
if is_end
else (node.local_node_score + self.context_score),
is_end=is_end, is_end=is_end,
) )
node = node.next[token] node = node.next[token]
@ -186,8 +185,6 @@ class ContextGraph:
if token in state.next: if token in state.next:
node = state.next[token] node = state.next[token]
score = node.token_score score = node.token_score
if state.is_end:
score += state.node_score
else: else:
# token not matched # token not matched
# We will trace along the fail arc until it matches the token or reaching # We will trace along the fail arc until it matches the token or reaching
@ -202,14 +199,9 @@ class ContextGraph:
node = node.next[token] node = node.next[token]
# The score of the fail path # The score of the fail path
score = node.node_score - state.local_node_score score = node.node_score - state.node_score
assert node is not None assert node is not None
matched_score = 0 return (score + node.output_score, node)
output = node.output
while output is not None:
matched_score += output.node_score
output = output.output
return (score + matched_score, node)
def finalize(self, state: ContextState) -> Tuple[float, ContextState]: def finalize(self, state: ContextState) -> Tuple[float, ContextState]:
"""When reaching the end of the decoded sequence, we need to finalize """When reaching the end of the decoded sequence, we need to finalize
@ -227,8 +219,6 @@ class ContextGraph:
""" """
# The score of the fail arc # The score of the fail arc
score = -state.node_score score = -state.node_score
if state.is_end:
score = 0
return (score, self.root) return (score, self.root)
def draw( def draw(
@ -307,10 +297,8 @@ class ContextGraph:
for token, node in current_node.next.items(): for token, node in current_node.next.items():
if node.id not in seen: if node.id not in seen:
node_score = f"{node.node_score:.2f}".rstrip("0").rstrip(".") node_score = f"{node.node_score:.2f}".rstrip("0").rstrip(".")
local_node_score = f"{node.local_node_score:.2f}".rstrip( output_score = f"{node.output_score:.2f}".rstrip("0").rstrip(".")
"0" label = f"{node.id}/({node_score}, {output_score})"
).rstrip(".")
label = f"{node.id}/({node_score},{local_node_score})"
if node.is_end: if node.is_end:
dot.node(str(node.id), label=label, **final_state_attr) dot.node(str(node.id), label=label, **final_state_attr)
else: else:
@ -391,6 +379,7 @@ if __name__ == "__main__":
"HERSHE": 12, # "HE", "HERS", "S", "SHE", "HE" "HERSHE": 12, # "HE", "HERS", "S", "SHE", "HE"
"HISHE": 9, # "HIS", "S", "SHE", "HE" "HISHE": 9, # "HIS", "S", "SHE", "HE"
"SHED": 6, # "S", "SHE", "HE" "SHED": 6, # "S", "SHE", "HE"
"SHELF": 6, # "S", "SHE", "HE"
"HELL": 2, # "HE" "HELL": 2, # "HE"
"HELLO": 7, # "HE", "HELLO" "HELLO": 7, # "HE", "HELLO"
"DHRHISQ": 4, # "HIS", "S" "DHRHISQ": 4, # "HIS", "S"