mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
add output arc; fix black
This commit is contained in:
parent
40a05810dd
commit
949e49eec8
@ -449,9 +449,7 @@ class LibriSpeechAsrDataModule:
|
||||
@lru_cache()
|
||||
def test_book_cuts(self) -> CutSet:
|
||||
logging.info("About to get test-books cuts")
|
||||
return load_manifest_lazy(
|
||||
self.args.manifest_dir / "libri_books_feats.jsonl.gz"
|
||||
)
|
||||
return load_manifest_lazy(self.args.manifest_dir / "libri_books_feats.jsonl.gz")
|
||||
|
||||
@lru_cache()
|
||||
def test_book_test_cuts(self) -> CutSet:
|
||||
|
||||
@ -61,6 +61,7 @@ class ContextState:
|
||||
self.is_end = is_end
|
||||
self.next = {}
|
||||
self.fail = None
|
||||
self.output = None
|
||||
|
||||
|
||||
class ContextGraph:
|
||||
@ -97,7 +98,7 @@ class ContextGraph:
|
||||
)
|
||||
self.root.fail = self.root
|
||||
|
||||
def _fill_fail(self):
|
||||
def _fill_fail_output(self):
|
||||
"""This function fills the fail arc for each trie node, it can be computed
|
||||
in linear time by performing a breadth-first search starting from the root.
|
||||
See https://en.wikipedia.org/wiki/Aho%E2%80%93Corasick_algorithm for the
|
||||
@ -122,6 +123,14 @@ class ContextGraph:
|
||||
if token in fail.next:
|
||||
fail = fail.next[token]
|
||||
node.fail = fail
|
||||
# fill the output arc
|
||||
output = node.fail
|
||||
while not output.is_end:
|
||||
output = output.fail
|
||||
if output.token == -1: # root
|
||||
output = None
|
||||
break
|
||||
node.output = output
|
||||
queue.append(node)
|
||||
|
||||
def build(self, token_ids: List[List[int]]):
|
||||
@ -155,7 +164,7 @@ class ContextGraph:
|
||||
is_end=is_end,
|
||||
)
|
||||
node = node.next[token]
|
||||
self._fill_fail()
|
||||
self._fill_fail_output()
|
||||
|
||||
def forward_one_step(
|
||||
self, state: ContextState, token: int
|
||||
@ -171,11 +180,14 @@ class ContextGraph:
|
||||
Returns:
|
||||
Return a tuple of score and next state.
|
||||
"""
|
||||
node = None
|
||||
score = 0
|
||||
# token matched
|
||||
if token in state.next:
|
||||
node = state.next[token]
|
||||
score = node.token_score
|
||||
return (score, node)
|
||||
if state.is_end:
|
||||
score += state.node_score
|
||||
else:
|
||||
# token not matched
|
||||
# We will trace along the fail arc until it matches the token or reaching
|
||||
@ -191,7 +203,13 @@ class ContextGraph:
|
||||
|
||||
# The score of the fail path
|
||||
score = node.node_score - state.local_node_score
|
||||
return (score, node)
|
||||
assert node is not None
|
||||
matched_score = 0
|
||||
output = node.output
|
||||
while output is not None:
|
||||
matched_score += output.node_score
|
||||
output = output.output
|
||||
return (score + matched_score, node)
|
||||
|
||||
def finalize(self, state: ContextState) -> Tuple[float, ContextState]:
|
||||
"""When reaching the end of the decoded sequence, we need to finalize
|
||||
@ -208,7 +226,9 @@ class ContextGraph:
|
||||
to root. The next state is always root.
|
||||
"""
|
||||
# The score of the fail arc
|
||||
score = self.root.node_score - state.local_node_score
|
||||
score = -state.node_score
|
||||
if state.is_end:
|
||||
score = 0
|
||||
return (score, self.root)
|
||||
|
||||
def draw(
|
||||
@ -279,7 +299,7 @@ class ContextGraph:
|
||||
queue.append(self.root)
|
||||
# root id is always 0
|
||||
dot.node("0", label="0", **default_node_attr)
|
||||
dot.edge("0", "0", label=f"*/0")
|
||||
dot.edge("0", "0", color="red")
|
||||
seen.add(0)
|
||||
|
||||
while len(queue):
|
||||
@ -304,6 +324,12 @@ class ContextGraph:
|
||||
str(node.fail.id),
|
||||
color="red",
|
||||
)
|
||||
if node.output is not None:
|
||||
dot.edge(
|
||||
str(node.id),
|
||||
str(node.output.id),
|
||||
color="green",
|
||||
)
|
||||
queue.append(node)
|
||||
|
||||
if filename:
|
||||
@ -331,7 +357,17 @@ class ContextGraph:
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
contexts_str = ["HE", "SHE", "SHELL", "HIS", "HERS", "HELLO"]
|
||||
contexts_str = [
|
||||
"S",
|
||||
"HE",
|
||||
"SHE",
|
||||
"SHELL",
|
||||
"HIS",
|
||||
"HERS",
|
||||
"HELLO",
|
||||
"THIS",
|
||||
"THEM",
|
||||
]
|
||||
contexts = []
|
||||
for s in contexts_str:
|
||||
contexts.append([ord(x) for x in s])
|
||||
@ -350,9 +386,17 @@ if __name__ == "__main__":
|
||||
symbol_table=symbol_table,
|
||||
)
|
||||
|
||||
queries = ["HERSHE", "HISHE", "SHED", "HELL", "HELLO", "DHRHISQ"]
|
||||
expected_scores = [7, 6, 3, 2, 5, 3]
|
||||
for i, query in enumerate(queries):
|
||||
queries = {
|
||||
"HEHERSHE": 14, # "HE", "HE", "HERS", "S", "SHE", "HE"
|
||||
"HERSHE": 12, # "HE", "HERS", "S", "SHE", "HE"
|
||||
"HISHE": 9, # "HIS", "S", "SHE", "HE"
|
||||
"SHED": 6, # "S", "SHE", "HE"
|
||||
"HELL": 2, # "HE"
|
||||
"HELLO": 7, # "HE", "HELLO"
|
||||
"DHRHISQ": 4, # "HIS", "S"
|
||||
"THEN": 2, # "HE"
|
||||
}
|
||||
for query, expected_score in queries.items():
|
||||
total_scores = 0
|
||||
state = context_graph.root
|
||||
for q in query:
|
||||
@ -361,8 +405,8 @@ if __name__ == "__main__":
|
||||
score, state = context_graph.finalize(state)
|
||||
assert state.token == -1, state.token
|
||||
total_scores += score
|
||||
assert total_scores == expected_scores[i], (
|
||||
assert total_scores == expected_score, (
|
||||
total_scores,
|
||||
expected_scores[i],
|
||||
expected_score,
|
||||
query,
|
||||
)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user