mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
Fixes to forward_one_step; add draw to context graph
This commit is contained in:
parent
62557a1564
commit
40a05810dd
@ -919,7 +919,7 @@ def main():
|
|||||||
for line in open(params.context_file).readlines():
|
for line in open(params.context_file).readlines():
|
||||||
contexts.append(line.strip())
|
contexts.append(line.strip())
|
||||||
context_graph = ContextGraph(params.context_score)
|
context_graph = ContextGraph(params.context_score)
|
||||||
context_graph.build_context_graph(sp.encode(contexts))
|
context_graph.build(sp.encode(contexts))
|
||||||
else:
|
else:
|
||||||
context_graph = None
|
context_graph = None
|
||||||
else:
|
else:
|
||||||
|
|||||||
@ -855,7 +855,7 @@ def main():
|
|||||||
for line in open(params.context_file).readlines():
|
for line in open(params.context_file).readlines():
|
||||||
contexts.append(graph_compiler.texts_to_ids(line.strip()))
|
contexts.append(graph_compiler.texts_to_ids(line.strip()))
|
||||||
context_graph = ContextGraph(params.context_score)
|
context_graph = ContextGraph(params.context_score)
|
||||||
context_graph.build_context_graph(contexts)
|
context_graph.build(contexts)
|
||||||
else:
|
else:
|
||||||
context_graph = None
|
context_graph = None
|
||||||
else:
|
else:
|
||||||
|
|||||||
@ -14,7 +14,10 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from typing import Dict, List, Tuple
|
import os
|
||||||
|
import shutil
|
||||||
|
from collections import deque
|
||||||
|
from typing import Dict, List, Optional, Tuple
|
||||||
|
|
||||||
|
|
||||||
class ContextState:
|
class ContextState:
|
||||||
@ -22,28 +25,39 @@ class ContextState:
|
|||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
id: int,
|
||||||
token: int,
|
token: int,
|
||||||
score: float,
|
token_score: float,
|
||||||
total_score: float,
|
node_score: float,
|
||||||
|
local_node_score: float,
|
||||||
is_end: bool,
|
is_end: bool,
|
||||||
):
|
):
|
||||||
"""Create a ContextState.
|
"""Create a ContextState.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
id:
|
||||||
|
The node id, only for visualization now. A node is in [0, graph.num_nodes).
|
||||||
|
The id of the root node is always 0.
|
||||||
token:
|
token:
|
||||||
The token id.
|
The token id.
|
||||||
score:
|
score:
|
||||||
The bonus for each token during decoding, which will hopefully
|
The bonus for each token during decoding, which will hopefully
|
||||||
boost the token up to survive beam search.
|
boost the token up to survive beam search.
|
||||||
total_score:
|
node_score:
|
||||||
The accumulated bonus from root of graph to current node, it will be
|
The accumulated bonus from root of graph to current node, it will be
|
||||||
used to calculate the score for fail arc.
|
used to calculate the score for fail arc.
|
||||||
|
local_node_score:
|
||||||
|
The accumulated bonus from last ``end_node``(node with is_end true)
|
||||||
|
to current_node, it will be used to calculate the score for fail arc.
|
||||||
|
Node: The local_node_score of a ``end_node`` is 0.
|
||||||
is_end:
|
is_end:
|
||||||
True if current token is the end of a context.
|
True if current token is the end of a context.
|
||||||
"""
|
"""
|
||||||
|
self.id = id
|
||||||
self.token = token
|
self.token = token
|
||||||
self.score = score
|
self.token_score = token_score
|
||||||
self.total_score = total_score
|
self.node_score = node_score
|
||||||
|
self.local_node_score = local_node_score
|
||||||
self.is_end = is_end
|
self.is_end = is_end
|
||||||
self.next = {}
|
self.next = {}
|
||||||
self.fail = None
|
self.fail = None
|
||||||
@ -72,7 +86,15 @@ class ContextGraph:
|
|||||||
word/phrase will have larger bonus score, they have to be matched though).
|
word/phrase will have larger bonus score, they have to be matched though).
|
||||||
"""
|
"""
|
||||||
self.context_score = context_score
|
self.context_score = context_score
|
||||||
self.root = ContextState(token=-1, score=0, total_score=0, is_end=False)
|
self.num_nodes = 0
|
||||||
|
self.root = ContextState(
|
||||||
|
id=self.num_nodes,
|
||||||
|
token=-1,
|
||||||
|
token_score=0,
|
||||||
|
node_score=0,
|
||||||
|
local_node_score=0,
|
||||||
|
is_end=False,
|
||||||
|
)
|
||||||
self.root.fail = self.root
|
self.root.fail = self.root
|
||||||
|
|
||||||
def _fill_fail(self):
|
def _fill_fail(self):
|
||||||
@ -81,12 +103,12 @@ class ContextGraph:
|
|||||||
See https://en.wikipedia.org/wiki/Aho%E2%80%93Corasick_algorithm for the
|
See https://en.wikipedia.org/wiki/Aho%E2%80%93Corasick_algorithm for the
|
||||||
details of the algorithm.
|
details of the algorithm.
|
||||||
"""
|
"""
|
||||||
queue = []
|
queue = deque()
|
||||||
for token, node in self.root.next.items():
|
for token, node in self.root.next.items():
|
||||||
node.fail = self.root
|
node.fail = self.root
|
||||||
queue.append(node)
|
queue.append(node)
|
||||||
while queue:
|
while queue:
|
||||||
current_node = queue.pop(0)
|
current_node = queue.popleft()
|
||||||
for token, node in current_node.next.items():
|
for token, node in current_node.next.items():
|
||||||
fail = current_node.fail
|
fail = current_node.fail
|
||||||
if token in fail.next:
|
if token in fail.next:
|
||||||
@ -102,7 +124,7 @@ class ContextGraph:
|
|||||||
node.fail = fail
|
node.fail = fail
|
||||||
queue.append(node)
|
queue.append(node)
|
||||||
|
|
||||||
def build_context_graph(self, token_ids: List[List[int]]):
|
def build(self, token_ids: List[List[int]]):
|
||||||
"""Build the ContextGraph from a list of token list.
|
"""Build the ContextGraph from a list of token list.
|
||||||
It first build a trie from the given token lists, then fill the fail arc
|
It first build a trie from the given token lists, then fill the fail arc
|
||||||
for each trie node.
|
for each trie node.
|
||||||
@ -120,13 +142,17 @@ class ContextGraph:
|
|||||||
node = self.root
|
node = self.root
|
||||||
for i, token in enumerate(tokens):
|
for i, token in enumerate(tokens):
|
||||||
if token not in node.next:
|
if token not in node.next:
|
||||||
|
self.num_nodes += 1
|
||||||
|
is_end = i == len(tokens) - 1
|
||||||
node.next[token] = ContextState(
|
node.next[token] = ContextState(
|
||||||
|
id=self.num_nodes,
|
||||||
token=token,
|
token=token,
|
||||||
score=self.context_score,
|
token_score=self.context_score,
|
||||||
# The total score is the accumulated score from root to current node,
|
node_score=node.node_score + self.context_score,
|
||||||
# it will be used to calculate the score of fail arc later.
|
local_node_score=0
|
||||||
total_score=node.total_score + self.context_score,
|
if is_end
|
||||||
is_end=i == len(tokens) - 1,
|
else (node.local_node_score + self.context_score),
|
||||||
|
is_end=is_end,
|
||||||
)
|
)
|
||||||
node = node.next[token]
|
node = node.next[token]
|
||||||
self._fill_fail()
|
self._fill_fail()
|
||||||
@ -138,7 +164,7 @@ class ContextGraph:
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
state:
|
state:
|
||||||
The given state (trie node) to start.
|
The given token containing trie node to start.
|
||||||
token:
|
token:
|
||||||
The given token.
|
The given token.
|
||||||
|
|
||||||
@ -148,9 +174,7 @@ class ContextGraph:
|
|||||||
# token matched
|
# token matched
|
||||||
if token in state.next:
|
if token in state.next:
|
||||||
node = state.next[token]
|
node = state.next[token]
|
||||||
score = node.score
|
score = node.token_score
|
||||||
if node.is_end:
|
|
||||||
node = self.root
|
|
||||||
return (score, node)
|
return (score, node)
|
||||||
else:
|
else:
|
||||||
# token not matched
|
# token not matched
|
||||||
@ -164,10 +188,9 @@ class ContextGraph:
|
|||||||
|
|
||||||
if token in node.next:
|
if token in node.next:
|
||||||
node = node.next[token]
|
node = node.next[token]
|
||||||
# The score of the fail arc
|
|
||||||
score = node.total_score - state.total_score
|
# The score of the fail path
|
||||||
if node.is_end:
|
score = node.node_score - state.local_node_score
|
||||||
node = self.root
|
|
||||||
return (score, node)
|
return (score, node)
|
||||||
|
|
||||||
def finalize(self, state: ContextState) -> Tuple[float, ContextState]:
|
def finalize(self, state: ContextState) -> Tuple[float, ContextState]:
|
||||||
@ -185,49 +208,161 @@ class ContextGraph:
|
|||||||
to root. The next state is always root.
|
to root. The next state is always root.
|
||||||
"""
|
"""
|
||||||
# The score of the fail arc
|
# The score of the fail arc
|
||||||
score = self.root.total_score - state.total_score
|
score = self.root.node_score - state.local_node_score
|
||||||
if state.is_end:
|
|
||||||
score = 0
|
|
||||||
return (score, self.root)
|
return (score, self.root)
|
||||||
|
|
||||||
|
def draw(
|
||||||
|
self,
|
||||||
|
title: Optional[str] = None,
|
||||||
|
filename: Optional[str] = "",
|
||||||
|
symbol_table: Optional[Dict[int, str]] = None,
|
||||||
|
) -> "Digraph": # noqa
|
||||||
|
|
||||||
|
"""Visualize a ContextGraph via graphviz.
|
||||||
|
|
||||||
|
Render ContextGraph as an image via graphviz, and return the Digraph object;
|
||||||
|
and optionally save to file `filename`.
|
||||||
|
`filename` must have a suffix that graphviz understands, such as
|
||||||
|
`pdf`, `svg` or `png`.
|
||||||
|
|
||||||
|
Note:
|
||||||
|
You need to install graphviz to use this function::
|
||||||
|
|
||||||
|
pip install graphviz
|
||||||
|
|
||||||
|
Args:
|
||||||
|
title:
|
||||||
|
Title to be displayed in image, e.g. 'A simple FSA example'
|
||||||
|
filename:
|
||||||
|
Filename to (optionally) save to, e.g. 'foo.png', 'foo.svg',
|
||||||
|
'foo.png' (must have a suffix that graphviz understands).
|
||||||
|
symbol_table:
|
||||||
|
Map the token ids to symbols.
|
||||||
|
Returns:
|
||||||
|
A Diagraph from grahpviz.
|
||||||
|
"""
|
||||||
|
|
||||||
|
try:
|
||||||
|
import graphviz
|
||||||
|
except Exception:
|
||||||
|
print("You cannot use `to_dot` unless the graphviz package is installed.")
|
||||||
|
raise
|
||||||
|
|
||||||
|
graph_attr = {
|
||||||
|
"rankdir": "LR",
|
||||||
|
"size": "8.5,11",
|
||||||
|
"center": "1",
|
||||||
|
"orientation": "Portrait",
|
||||||
|
"ranksep": "0.4",
|
||||||
|
"nodesep": "0.25",
|
||||||
|
}
|
||||||
|
if title is not None:
|
||||||
|
graph_attr["label"] = title
|
||||||
|
|
||||||
|
default_node_attr = {
|
||||||
|
"shape": "circle",
|
||||||
|
"style": "bold",
|
||||||
|
"fontsize": "14",
|
||||||
|
}
|
||||||
|
|
||||||
|
final_state_attr = {
|
||||||
|
"shape": "doublecircle",
|
||||||
|
"style": "bold",
|
||||||
|
"fontsize": "14",
|
||||||
|
}
|
||||||
|
|
||||||
|
final_state = -1
|
||||||
|
dot = graphviz.Digraph(name="Context Graph", graph_attr=graph_attr)
|
||||||
|
|
||||||
|
seen = set()
|
||||||
|
queue = deque()
|
||||||
|
queue.append(self.root)
|
||||||
|
# root id is always 0
|
||||||
|
dot.node("0", label="0", **default_node_attr)
|
||||||
|
dot.edge("0", "0", label=f"*/0")
|
||||||
|
seen.add(0)
|
||||||
|
|
||||||
|
while len(queue):
|
||||||
|
current_node = queue.popleft()
|
||||||
|
for token, node in current_node.next.items():
|
||||||
|
if node.id not in seen:
|
||||||
|
node_score = f"{node.node_score:.2f}".rstrip("0").rstrip(".")
|
||||||
|
local_node_score = f"{node.local_node_score:.2f}".rstrip(
|
||||||
|
"0"
|
||||||
|
).rstrip(".")
|
||||||
|
label = f"{node.id}/({node_score},{local_node_score})"
|
||||||
|
if node.is_end:
|
||||||
|
dot.node(str(node.id), label=label, **final_state_attr)
|
||||||
|
else:
|
||||||
|
dot.node(str(node.id), label=label, **default_node_attr)
|
||||||
|
seen.add(node.id)
|
||||||
|
weight = f"{node.token_score:.2f}".rstrip("0").rstrip(".")
|
||||||
|
label = str(token) if symbol_table is None else symbol_table[token]
|
||||||
|
dot.edge(str(current_node.id), str(node.id), label=f"{label}/{weight}")
|
||||||
|
dot.edge(
|
||||||
|
str(node.id),
|
||||||
|
str(node.fail.id),
|
||||||
|
color="red",
|
||||||
|
)
|
||||||
|
queue.append(node)
|
||||||
|
|
||||||
|
if filename:
|
||||||
|
_, extension = os.path.splitext(filename)
|
||||||
|
if extension == "" or extension[0] != ".":
|
||||||
|
raise ValueError(
|
||||||
|
"Filename needs to have a suffix like .png, .pdf, .svg: {}".format(
|
||||||
|
filename
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
import tempfile
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
|
temp_fn = dot.render(
|
||||||
|
filename="temp",
|
||||||
|
directory=tmp_dir,
|
||||||
|
format=extension[1:],
|
||||||
|
cleanup=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
shutil.move(temp_fn, filename)
|
||||||
|
|
||||||
|
return dot
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
contexts_str = ["HE", "SHE", "HIS", "HERS"]
|
contexts_str = ["HE", "SHE", "SHELL", "HIS", "HERS", "HELLO"]
|
||||||
contexts = []
|
contexts = []
|
||||||
for s in contexts_str:
|
for s in contexts_str:
|
||||||
contexts.append([ord(x) for x in s])
|
contexts.append([ord(x) for x in s])
|
||||||
|
|
||||||
context_graph = ContextGraph(context_score=2)
|
context_graph = ContextGraph(context_score=1)
|
||||||
context_graph.build_context_graph(contexts)
|
context_graph.build(contexts)
|
||||||
|
|
||||||
score, state = context_graph.forward_one_step(context_graph.root, ord("H"))
|
symbol_table = {}
|
||||||
assert score == 2, score
|
for contexts in contexts_str:
|
||||||
assert state.token == ord("H"), state.token
|
for s in contexts:
|
||||||
|
symbol_table[ord(s)] = s
|
||||||
|
|
||||||
score, state = context_graph.forward_one_step(state, ord("I"))
|
context_graph.draw(
|
||||||
assert score == 2, score
|
title="Graph for: " + " / ".join(contexts_str),
|
||||||
assert state.token == ord("I"), state.token
|
filename="context_graph.pdf",
|
||||||
|
symbol_table=symbol_table,
|
||||||
score, state = context_graph.forward_one_step(state, ord("S"))
|
)
|
||||||
assert score == 2, score
|
|
||||||
assert state.token == -1, state.token
|
|
||||||
|
|
||||||
|
queries = ["HERSHE", "HISHE", "SHED", "HELL", "HELLO", "DHRHISQ"]
|
||||||
|
expected_scores = [7, 6, 3, 2, 5, 3]
|
||||||
|
for i, query in enumerate(queries):
|
||||||
|
total_scores = 0
|
||||||
|
state = context_graph.root
|
||||||
|
for q in query:
|
||||||
|
score, state = context_graph.forward_one_step(state, ord(q))
|
||||||
|
total_scores += score
|
||||||
score, state = context_graph.finalize(state)
|
score, state = context_graph.finalize(state)
|
||||||
assert score == 0, score
|
|
||||||
assert state.token == -1, state.token
|
|
||||||
|
|
||||||
score, state = context_graph.forward_one_step(context_graph.root, ord("S"))
|
|
||||||
assert score == 2, score
|
|
||||||
assert state.token == ord("S"), state.token
|
|
||||||
|
|
||||||
score, state = context_graph.forward_one_step(state, ord("H"))
|
|
||||||
assert score == 2, score
|
|
||||||
assert state.token == ord("H"), state.token
|
|
||||||
|
|
||||||
score, state = context_graph.forward_one_step(state, ord("D"))
|
|
||||||
assert score == -4, score
|
|
||||||
assert state.token == -1, state.token
|
|
||||||
|
|
||||||
score, state = context_graph.forward_one_step(context_graph.root, ord("D"))
|
|
||||||
assert score == 0, score
|
|
||||||
assert state.token == -1, state.token
|
assert state.token == -1, state.token
|
||||||
|
total_scores += score
|
||||||
|
assert total_scores == expected_scores[i], (
|
||||||
|
total_scores,
|
||||||
|
expected_scores[i],
|
||||||
|
query,
|
||||||
|
)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user