Fix some bugs in ContextGraph (fail-pointer construction and forward_one_step)

This commit is contained in:
pkufool 2023-05-10 00:00:59 +08:00
parent 2e7e7875f5
commit 62557a1564

View File

@ -20,7 +20,13 @@ from typing import Dict, List, Tuple
class ContextState:
"""The state in ContextGraph"""
def __init__(self, token: int, score: float, total_score: float, is_end: bool):
def __init__(
self,
token: int,
score: float,
total_score: float,
is_end: bool,
):
"""Create a ContextState.
Args:
@ -81,11 +87,18 @@ class ContextGraph:
queue.append(node)
while queue:
current_node = queue.pop(0)
current_fail = current_node.fail
for token, node in current_node.next.items():
fail = current_fail
if token in current_fail.next:
fail = current_fail.next[token]
fail = current_node.fail
if token in fail.next:
fail = fail.next[token]
else:
fail = fail.fail
while token not in fail.next:
fail = fail.fail
if fail.token == -1: # root
break
if token in fail.next:
fail = fail.next[token]
node.fail = fail
queue.append(node)
@ -116,7 +129,7 @@ class ContextGraph:
is_end=i == len(tokens) - 1,
)
node = node.next[token]
self._fill_fail()
self._fill_fail()
def forward_one_step(
self, state: ContextState, token: int
@ -136,8 +149,6 @@ class ContextGraph:
if token in state.next:
node = state.next[token]
score = node.score
# if the matched node is the end of a word/phrase, we will start
# from the root for next token.
if node.is_end:
node = self.root
return (score, node)