add output arc; fix black

This commit is contained in:
pkufool 2023-05-11 19:16:44 +08:00
parent 40a05810dd
commit 949e49eec8
2 changed files with 57 additions and 15 deletions

View File

@ -449,9 +449,7 @@ class LibriSpeechAsrDataModule:
@lru_cache()
def test_book_cuts(self) -> CutSet:
logging.info("About to get test-books cuts")
return load_manifest_lazy(
self.args.manifest_dir / "libri_books_feats.jsonl.gz"
)
return load_manifest_lazy(self.args.manifest_dir / "libri_books_feats.jsonl.gz")
@lru_cache()
def test_book_test_cuts(self) -> CutSet:

View File

@ -61,6 +61,7 @@ class ContextState:
self.is_end = is_end
self.next = {}
self.fail = None
self.output = None
class ContextGraph:
@ -97,7 +98,7 @@ class ContextGraph:
)
self.root.fail = self.root
def _fill_fail(self):
def _fill_fail_output(self):
"""This function fills the fail arc for each trie node, it can be computed
in linear time by performing a breadth-first search starting from the root.
See https://en.wikipedia.org/wiki/Aho%E2%80%93Corasick_algorithm for the
@ -122,6 +123,14 @@ class ContextGraph:
if token in fail.next:
fail = fail.next[token]
node.fail = fail
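# fill the output arc: the first end-of-phrase state (is_end) reached by
# following fail arcs from this node, or None if root is reached first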
output = node.fail
while not output.is_end:
output = output.fail
if output.token == -1: # root
output = None
break
node.output = output
queue.append(node)
def build(self, token_ids: List[List[int]]):
@ -155,7 +164,7 @@ class ContextGraph:
is_end=is_end,
)
node = node.next[token]
self._fill_fail()
self._fill_fail_output()
def forward_one_step(
self, state: ContextState, token: int
@ -171,11 +180,14 @@ class ContextGraph:
Returns:
Return a tuple of score and next state.
"""
node = None
score = 0
# token matched
if token in state.next:
node = state.next[token]
score = node.token_score
return (score, node)
if state.is_end:
score += state.node_score
else:
# token not matched
# We will trace along the fail arc until it matches the token or reaches
@ -191,7 +203,13 @@ class ContextGraph:
# The score of the fail path
score = node.node_score - state.local_node_score
return (score, node)
assert node is not None
matched_score = 0
output = node.output
while output is not None:
matched_score += output.node_score
output = output.output
return (score + matched_score, node)
def finalize(self, state: ContextState) -> Tuple[float, ContextState]:
"""When reaching the end of the decoded sequence, we need to finalize
@ -208,7 +226,9 @@ class ContextGraph:
to root. The next state is always root.
"""
# The score of the fail arc
score = self.root.node_score - state.local_node_score
score = -state.node_score
if state.is_end:
score = 0
return (score, self.root)
def draw(
@ -279,7 +299,7 @@ class ContextGraph:
queue.append(self.root)
# root id is always 0
dot.node("0", label="0", **default_node_attr)
dot.edge("0", "0", label=f"*/0")
dot.edge("0", "0", color="red")
seen.add(0)
while len(queue):
@ -304,6 +324,12 @@ class ContextGraph:
str(node.fail.id),
color="red",
)
if node.output is not None:
dot.edge(
str(node.id),
str(node.output.id),
color="green",
)
queue.append(node)
if filename:
@ -331,7 +357,17 @@ class ContextGraph:
if __name__ == "__main__":
contexts_str = ["HE", "SHE", "SHELL", "HIS", "HERS", "HELLO"]
contexts_str = [
"S",
"HE",
"SHE",
"SHELL",
"HIS",
"HERS",
"HELLO",
"THIS",
"THEM",
]
contexts = []
for s in contexts_str:
contexts.append([ord(x) for x in s])
@ -350,9 +386,17 @@ if __name__ == "__main__":
symbol_table=symbol_table,
)
queries = ["HERSHE", "HISHE", "SHED", "HELL", "HELLO", "DHRHISQ"]
expected_scores = [7, 6, 3, 2, 5, 3]
for i, query in enumerate(queries):
queries = {
"HEHERSHE": 14, # "HE", "HE", "HERS", "S", "SHE", "HE"
"HERSHE": 12, # "HE", "HERS", "S", "SHE", "HE"
"HISHE": 9, # "HIS", "S", "SHE", "HE"
"SHED": 6, # "S", "SHE", "HE"
"HELL": 2, # "HE"
"HELLO": 7, # "HE", "HELLO"
"DHRHISQ": 4, # "HIS", "S"
"THEN": 2, # "HE"
}
for query, expected_score in queries.items():
total_scores = 0
state = context_graph.root
for q in query:
@ -361,8 +405,8 @@ if __name__ == "__main__":
score, state = context_graph.finalize(state)
assert state.token == -1, state.token
total_scores += score
assert total_scores == expected_scores[i], (
assert total_scores == expected_score, (
total_scores,
expected_scores[i],
expected_score,
query,
)