mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
add output arc; fix black
This commit is contained in:
parent
40a05810dd
commit
949e49eec8
@ -449,9 +449,7 @@ class LibriSpeechAsrDataModule:
|
|||||||
@lru_cache()
|
@lru_cache()
|
||||||
def test_book_cuts(self) -> CutSet:
|
def test_book_cuts(self) -> CutSet:
|
||||||
logging.info("About to get test-books cuts")
|
logging.info("About to get test-books cuts")
|
||||||
return load_manifest_lazy(
|
return load_manifest_lazy(self.args.manifest_dir / "libri_books_feats.jsonl.gz")
|
||||||
self.args.manifest_dir / "libri_books_feats.jsonl.gz"
|
|
||||||
)
|
|
||||||
|
|
||||||
@lru_cache()
|
@lru_cache()
|
||||||
def test_book_test_cuts(self) -> CutSet:
|
def test_book_test_cuts(self) -> CutSet:
|
||||||
|
|||||||
@ -61,6 +61,7 @@ class ContextState:
|
|||||||
self.is_end = is_end
|
self.is_end = is_end
|
||||||
self.next = {}
|
self.next = {}
|
||||||
self.fail = None
|
self.fail = None
|
||||||
|
self.output = None
|
||||||
|
|
||||||
|
|
||||||
class ContextGraph:
|
class ContextGraph:
|
||||||
@ -97,7 +98,7 @@ class ContextGraph:
|
|||||||
)
|
)
|
||||||
self.root.fail = self.root
|
self.root.fail = self.root
|
||||||
|
|
||||||
def _fill_fail(self):
|
def _fill_fail_output(self):
|
||||||
"""This function fills the fail arc for each trie node, it can be computed
|
"""This function fills the fail arc for each trie node, it can be computed
|
||||||
in linear time by performing a breadth-first search starting from the root.
|
in linear time by performing a breadth-first search starting from the root.
|
||||||
See https://en.wikipedia.org/wiki/Aho%E2%80%93Corasick_algorithm for the
|
See https://en.wikipedia.org/wiki/Aho%E2%80%93Corasick_algorithm for the
|
||||||
@ -122,6 +123,14 @@ class ContextGraph:
|
|||||||
if token in fail.next:
|
if token in fail.next:
|
||||||
fail = fail.next[token]
|
fail = fail.next[token]
|
||||||
node.fail = fail
|
node.fail = fail
|
||||||
|
# fill the output arc
|
||||||
|
output = node.fail
|
||||||
|
while not output.is_end:
|
||||||
|
output = output.fail
|
||||||
|
if output.token == -1: # root
|
||||||
|
output = None
|
||||||
|
break
|
||||||
|
node.output = output
|
||||||
queue.append(node)
|
queue.append(node)
|
||||||
|
|
||||||
def build(self, token_ids: List[List[int]]):
|
def build(self, token_ids: List[List[int]]):
|
||||||
@ -155,7 +164,7 @@ class ContextGraph:
|
|||||||
is_end=is_end,
|
is_end=is_end,
|
||||||
)
|
)
|
||||||
node = node.next[token]
|
node = node.next[token]
|
||||||
self._fill_fail()
|
self._fill_fail_output()
|
||||||
|
|
||||||
def forward_one_step(
|
def forward_one_step(
|
||||||
self, state: ContextState, token: int
|
self, state: ContextState, token: int
|
||||||
@ -171,11 +180,14 @@ class ContextGraph:
|
|||||||
Returns:
|
Returns:
|
||||||
Return a tuple of score and next state.
|
Return a tuple of score and next state.
|
||||||
"""
|
"""
|
||||||
|
node = None
|
||||||
|
score = 0
|
||||||
# token matched
|
# token matched
|
||||||
if token in state.next:
|
if token in state.next:
|
||||||
node = state.next[token]
|
node = state.next[token]
|
||||||
score = node.token_score
|
score = node.token_score
|
||||||
return (score, node)
|
if state.is_end:
|
||||||
|
score += state.node_score
|
||||||
else:
|
else:
|
||||||
# token not matched
|
# token not matched
|
||||||
# We will trace along the fail arc until it matches the token or reaching
|
# We will trace along the fail arc until it matches the token or reaching
|
||||||
@ -191,7 +203,13 @@ class ContextGraph:
|
|||||||
|
|
||||||
# The score of the fail path
|
# The score of the fail path
|
||||||
score = node.node_score - state.local_node_score
|
score = node.node_score - state.local_node_score
|
||||||
return (score, node)
|
assert node is not None
|
||||||
|
matched_score = 0
|
||||||
|
output = node.output
|
||||||
|
while output is not None:
|
||||||
|
matched_score += output.node_score
|
||||||
|
output = output.output
|
||||||
|
return (score + matched_score, node)
|
||||||
|
|
||||||
def finalize(self, state: ContextState) -> Tuple[float, ContextState]:
|
def finalize(self, state: ContextState) -> Tuple[float, ContextState]:
|
||||||
"""When reaching the end of the decoded sequence, we need to finalize
|
"""When reaching the end of the decoded sequence, we need to finalize
|
||||||
@ -208,7 +226,9 @@ class ContextGraph:
|
|||||||
to root. The next state is always root.
|
to root. The next state is always root.
|
||||||
"""
|
"""
|
||||||
# The score of the fail arc
|
# The score of the fail arc
|
||||||
score = self.root.node_score - state.local_node_score
|
score = -state.node_score
|
||||||
|
if state.is_end:
|
||||||
|
score = 0
|
||||||
return (score, self.root)
|
return (score, self.root)
|
||||||
|
|
||||||
def draw(
|
def draw(
|
||||||
@ -279,7 +299,7 @@ class ContextGraph:
|
|||||||
queue.append(self.root)
|
queue.append(self.root)
|
||||||
# root id is always 0
|
# root id is always 0
|
||||||
dot.node("0", label="0", **default_node_attr)
|
dot.node("0", label="0", **default_node_attr)
|
||||||
dot.edge("0", "0", label=f"*/0")
|
dot.edge("0", "0", color="red")
|
||||||
seen.add(0)
|
seen.add(0)
|
||||||
|
|
||||||
while len(queue):
|
while len(queue):
|
||||||
@ -304,6 +324,12 @@ class ContextGraph:
|
|||||||
str(node.fail.id),
|
str(node.fail.id),
|
||||||
color="red",
|
color="red",
|
||||||
)
|
)
|
||||||
|
if node.output is not None:
|
||||||
|
dot.edge(
|
||||||
|
str(node.id),
|
||||||
|
str(node.output.id),
|
||||||
|
color="green",
|
||||||
|
)
|
||||||
queue.append(node)
|
queue.append(node)
|
||||||
|
|
||||||
if filename:
|
if filename:
|
||||||
@ -331,7 +357,17 @@ class ContextGraph:
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
contexts_str = ["HE", "SHE", "SHELL", "HIS", "HERS", "HELLO"]
|
contexts_str = [
|
||||||
|
"S",
|
||||||
|
"HE",
|
||||||
|
"SHE",
|
||||||
|
"SHELL",
|
||||||
|
"HIS",
|
||||||
|
"HERS",
|
||||||
|
"HELLO",
|
||||||
|
"THIS",
|
||||||
|
"THEM",
|
||||||
|
]
|
||||||
contexts = []
|
contexts = []
|
||||||
for s in contexts_str:
|
for s in contexts_str:
|
||||||
contexts.append([ord(x) for x in s])
|
contexts.append([ord(x) for x in s])
|
||||||
@ -350,9 +386,17 @@ if __name__ == "__main__":
|
|||||||
symbol_table=symbol_table,
|
symbol_table=symbol_table,
|
||||||
)
|
)
|
||||||
|
|
||||||
queries = ["HERSHE", "HISHE", "SHED", "HELL", "HELLO", "DHRHISQ"]
|
queries = {
|
||||||
expected_scores = [7, 6, 3, 2, 5, 3]
|
"HEHERSHE": 14, # "HE", "HE", "HERS", "S", "SHE", "HE"
|
||||||
for i, query in enumerate(queries):
|
"HERSHE": 12, # "HE", "HERS", "S", "SHE", "HE"
|
||||||
|
"HISHE": 9, # "HIS", "S", "SHE", "HE"
|
||||||
|
"SHED": 6, # "S", "SHE", "HE"
|
||||||
|
"HELL": 2, # "HE"
|
||||||
|
"HELLO": 7, # "HE", "HELLO"
|
||||||
|
"DHRHISQ": 4, # "HIS", "S"
|
||||||
|
"THEN": 2, # "HE"
|
||||||
|
}
|
||||||
|
for query, expected_score in queries.items():
|
||||||
total_scores = 0
|
total_scores = 0
|
||||||
state = context_graph.root
|
state = context_graph.root
|
||||||
for q in query:
|
for q in query:
|
||||||
@ -361,8 +405,8 @@ if __name__ == "__main__":
|
|||||||
score, state = context_graph.finalize(state)
|
score, state = context_graph.finalize(state)
|
||||||
assert state.token == -1, state.token
|
assert state.token == -1, state.token
|
||||||
total_scores += score
|
total_scores += score
|
||||||
assert total_scores == expected_scores[i], (
|
assert total_scores == expected_score, (
|
||||||
total_scores,
|
total_scores,
|
||||||
expected_scores[i],
|
expected_score,
|
||||||
query,
|
query,
|
||||||
)
|
)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user