add output arc; fix black

This commit is contained in:
pkufool 2023-05-11 19:16:44 +08:00
parent 40a05810dd
commit 949e49eec8
2 changed files with 57 additions and 15 deletions

View File

@@ -449,9 +449,7 @@ class LibriSpeechAsrDataModule:
@lru_cache() @lru_cache()
def test_book_cuts(self) -> CutSet: def test_book_cuts(self) -> CutSet:
logging.info("About to get test-books cuts") logging.info("About to get test-books cuts")
return load_manifest_lazy( return load_manifest_lazy(self.args.manifest_dir / "libri_books_feats.jsonl.gz")
self.args.manifest_dir / "libri_books_feats.jsonl.gz"
)
@lru_cache() @lru_cache()
def test_book_test_cuts(self) -> CutSet: def test_book_test_cuts(self) -> CutSet:

View File

@@ -61,6 +61,7 @@ class ContextState:
self.is_end = is_end self.is_end = is_end
self.next = {} self.next = {}
self.fail = None self.fail = None
self.output = None
class ContextGraph: class ContextGraph:
@@ -97,7 +98,7 @@ class ContextGraph:
) )
self.root.fail = self.root self.root.fail = self.root
def _fill_fail(self): def _fill_fail_output(self):
"""This function fills the fail arc for each trie node, it can be computed """This function fills the fail arc for each trie node, it can be computed
in linear time by performing a breadth-first search starting from the root. in linear time by performing a breadth-first search starting from the root.
See https://en.wikipedia.org/wiki/Aho%E2%80%93Corasick_algorithm for the See https://en.wikipedia.org/wiki/Aho%E2%80%93Corasick_algorithm for the
@@ -122,6 +123,14 @@ class ContextGraph:
if token in fail.next: if token in fail.next:
fail = fail.next[token] fail = fail.next[token]
node.fail = fail node.fail = fail
# fill the output arc
output = node.fail
while not output.is_end:
output = output.fail
if output.token == -1: # root
output = None
break
node.output = output
queue.append(node) queue.append(node)
def build(self, token_ids: List[List[int]]): def build(self, token_ids: List[List[int]]):
@@ -155,7 +164,7 @@ class ContextGraph:
is_end=is_end, is_end=is_end,
) )
node = node.next[token] node = node.next[token]
self._fill_fail() self._fill_fail_output()
def forward_one_step( def forward_one_step(
self, state: ContextState, token: int self, state: ContextState, token: int
@@ -171,11 +180,14 @@ class ContextGraph:
Returns: Returns:
Return a tuple of score and next state. Return a tuple of score and next state.
""" """
node = None
score = 0
# token matched # token matched
if token in state.next: if token in state.next:
node = state.next[token] node = state.next[token]
score = node.token_score score = node.token_score
return (score, node) if state.is_end:
score += state.node_score
else: else:
# token not matched # token not matched
# We will trace along the fail arc until it matches the token or reaching # We will trace along the fail arc until it matches the token or reaching
@@ -191,7 +203,13 @@ class ContextGraph:
# The score of the fail path # The score of the fail path
score = node.node_score - state.local_node_score score = node.node_score - state.local_node_score
return (score, node) assert node is not None
matched_score = 0
output = node.output
while output is not None:
matched_score += output.node_score
output = output.output
return (score + matched_score, node)
def finalize(self, state: ContextState) -> Tuple[float, ContextState]: def finalize(self, state: ContextState) -> Tuple[float, ContextState]:
"""When reaching the end of the decoded sequence, we need to finalize """When reaching the end of the decoded sequence, we need to finalize
@@ -208,7 +226,9 @@ class ContextGraph:
to root. The next state is always root. to root. The next state is always root.
""" """
# The score of the fail arc # The score of the fail arc
score = self.root.node_score - state.local_node_score score = -state.node_score
if state.is_end:
score = 0
return (score, self.root) return (score, self.root)
def draw( def draw(
@@ -279,7 +299,7 @@ class ContextGraph:
queue.append(self.root) queue.append(self.root)
# root id is always 0 # root id is always 0
dot.node("0", label="0", **default_node_attr) dot.node("0", label="0", **default_node_attr)
dot.edge("0", "0", label=f"*/0") dot.edge("0", "0", color="red")
seen.add(0) seen.add(0)
while len(queue): while len(queue):
@@ -304,6 +324,12 @@ class ContextGraph:
str(node.fail.id), str(node.fail.id),
color="red", color="red",
) )
if node.output is not None:
dot.edge(
str(node.id),
str(node.output.id),
color="green",
)
queue.append(node) queue.append(node)
if filename: if filename:
@@ -331,7 +357,17 @@ class ContextGraph:
if __name__ == "__main__": if __name__ == "__main__":
contexts_str = ["HE", "SHE", "SHELL", "HIS", "HERS", "HELLO"] contexts_str = [
"S",
"HE",
"SHE",
"SHELL",
"HIS",
"HERS",
"HELLO",
"THIS",
"THEM",
]
contexts = [] contexts = []
for s in contexts_str: for s in contexts_str:
contexts.append([ord(x) for x in s]) contexts.append([ord(x) for x in s])
@@ -350,9 +386,17 @@ if __name__ == "__main__":
symbol_table=symbol_table, symbol_table=symbol_table,
) )
queries = ["HERSHE", "HISHE", "SHED", "HELL", "HELLO", "DHRHISQ"] queries = {
expected_scores = [7, 6, 3, 2, 5, 3] "HEHERSHE": 14, # "HE", "HE", "HERS", "S", "SHE", "HE"
for i, query in enumerate(queries): "HERSHE": 12, # "HE", "HERS", "S", "SHE", "HE"
"HISHE": 9, # "HIS", "S", "SHE", "HE"
"SHED": 6, # "S", "SHE", "HE"
"HELL": 2, # "HE"
"HELLO": 7, # "HE", "HELLO"
"DHRHISQ": 4, # "HIS", "S"
"THEN": 2, # "HE"
}
for query, expected_score in queries.items():
total_scores = 0 total_scores = 0
state = context_graph.root state = context_graph.root
for q in query: for q in query:
@ -361,8 +405,8 @@ if __name__ == "__main__":
score, state = context_graph.finalize(state) score, state = context_graph.finalize(state)
assert state.token == -1, state.token assert state.token == -1, state.token
total_scores += score total_scores += score
assert total_scores == expected_scores[i], ( assert total_scores == expected_score, (
total_scores, total_scores,
expected_scores[i], expected_score,
query, query,
) )