mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
Fixes to forward_one_step; add draw to context graph
This commit is contained in:
parent
62557a1564
commit
40a05810dd
@ -919,7 +919,7 @@ def main():
|
||||
for line in open(params.context_file).readlines():
|
||||
contexts.append(line.strip())
|
||||
context_graph = ContextGraph(params.context_score)
|
||||
context_graph.build_context_graph(sp.encode(contexts))
|
||||
context_graph.build(sp.encode(contexts))
|
||||
else:
|
||||
context_graph = None
|
||||
else:
|
||||
|
||||
@ -855,7 +855,7 @@ def main():
|
||||
for line in open(params.context_file).readlines():
|
||||
contexts.append(graph_compiler.texts_to_ids(line.strip()))
|
||||
context_graph = ContextGraph(params.context_score)
|
||||
context_graph.build_context_graph(contexts)
|
||||
context_graph.build(contexts)
|
||||
else:
|
||||
context_graph = None
|
||||
else:
|
||||
|
||||
@ -14,7 +14,10 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Dict, List, Tuple
|
||||
import os
|
||||
import shutil
|
||||
from collections import deque
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
|
||||
class ContextState:
|
||||
@ -22,28 +25,39 @@ class ContextState:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
id: int,
|
||||
token: int,
|
||||
score: float,
|
||||
total_score: float,
|
||||
token_score: float,
|
||||
node_score: float,
|
||||
local_node_score: float,
|
||||
is_end: bool,
|
||||
):
|
||||
"""Create a ContextState.
|
||||
|
||||
Args:
|
||||
id:
|
||||
The node id, only for visualization now. A node is in [0, graph.num_nodes).
|
||||
The id of the root node is always 0.
|
||||
token:
|
||||
The token id.
|
||||
score:
|
||||
The bonus for each token during decoding, which will hopefully
|
||||
boost the token up to survive beam search.
|
||||
total_score:
|
||||
node_score:
|
||||
The accumulated bonus from root of graph to current node, it will be
|
||||
used to calculate the score for fail arc.
|
||||
local_node_score:
|
||||
The accumulated bonus from last ``end_node``(node with is_end true)
|
||||
to current_node, it will be used to calculate the score for fail arc.
|
||||
Node: The local_node_score of a ``end_node`` is 0.
|
||||
is_end:
|
||||
True if current token is the end of a context.
|
||||
"""
|
||||
self.id = id
|
||||
self.token = token
|
||||
self.score = score
|
||||
self.total_score = total_score
|
||||
self.token_score = token_score
|
||||
self.node_score = node_score
|
||||
self.local_node_score = local_node_score
|
||||
self.is_end = is_end
|
||||
self.next = {}
|
||||
self.fail = None
|
||||
@ -72,7 +86,15 @@ class ContextGraph:
|
||||
word/phrase will have larger bonus score, they have to be matched though).
|
||||
"""
|
||||
self.context_score = context_score
|
||||
self.root = ContextState(token=-1, score=0, total_score=0, is_end=False)
|
||||
self.num_nodes = 0
|
||||
self.root = ContextState(
|
||||
id=self.num_nodes,
|
||||
token=-1,
|
||||
token_score=0,
|
||||
node_score=0,
|
||||
local_node_score=0,
|
||||
is_end=False,
|
||||
)
|
||||
self.root.fail = self.root
|
||||
|
||||
def _fill_fail(self):
|
||||
@ -81,12 +103,12 @@ class ContextGraph:
|
||||
See https://en.wikipedia.org/wiki/Aho%E2%80%93Corasick_algorithm for the
|
||||
details of the algorithm.
|
||||
"""
|
||||
queue = []
|
||||
queue = deque()
|
||||
for token, node in self.root.next.items():
|
||||
node.fail = self.root
|
||||
queue.append(node)
|
||||
while queue:
|
||||
current_node = queue.pop(0)
|
||||
current_node = queue.popleft()
|
||||
for token, node in current_node.next.items():
|
||||
fail = current_node.fail
|
||||
if token in fail.next:
|
||||
@ -102,7 +124,7 @@ class ContextGraph:
|
||||
node.fail = fail
|
||||
queue.append(node)
|
||||
|
||||
def build_context_graph(self, token_ids: List[List[int]]):
|
||||
def build(self, token_ids: List[List[int]]):
|
||||
"""Build the ContextGraph from a list of token list.
|
||||
It first build a trie from the given token lists, then fill the fail arc
|
||||
for each trie node.
|
||||
@ -120,13 +142,17 @@ class ContextGraph:
|
||||
node = self.root
|
||||
for i, token in enumerate(tokens):
|
||||
if token not in node.next:
|
||||
self.num_nodes += 1
|
||||
is_end = i == len(tokens) - 1
|
||||
node.next[token] = ContextState(
|
||||
id=self.num_nodes,
|
||||
token=token,
|
||||
score=self.context_score,
|
||||
# The total score is the accumulated score from root to current node,
|
||||
# it will be used to calculate the score of fail arc later.
|
||||
total_score=node.total_score + self.context_score,
|
||||
is_end=i == len(tokens) - 1,
|
||||
token_score=self.context_score,
|
||||
node_score=node.node_score + self.context_score,
|
||||
local_node_score=0
|
||||
if is_end
|
||||
else (node.local_node_score + self.context_score),
|
||||
is_end=is_end,
|
||||
)
|
||||
node = node.next[token]
|
||||
self._fill_fail()
|
||||
@ -138,7 +164,7 @@ class ContextGraph:
|
||||
|
||||
Args:
|
||||
state:
|
||||
The given state (trie node) to start.
|
||||
The given token containing trie node to start.
|
||||
token:
|
||||
The given token.
|
||||
|
||||
@ -148,9 +174,7 @@ class ContextGraph:
|
||||
# token matched
|
||||
if token in state.next:
|
||||
node = state.next[token]
|
||||
score = node.score
|
||||
if node.is_end:
|
||||
node = self.root
|
||||
score = node.token_score
|
||||
return (score, node)
|
||||
else:
|
||||
# token not matched
|
||||
@ -164,10 +188,9 @@ class ContextGraph:
|
||||
|
||||
if token in node.next:
|
||||
node = node.next[token]
|
||||
# The score of the fail arc
|
||||
score = node.total_score - state.total_score
|
||||
if node.is_end:
|
||||
node = self.root
|
||||
|
||||
# The score of the fail path
|
||||
score = node.node_score - state.local_node_score
|
||||
return (score, node)
|
||||
|
||||
def finalize(self, state: ContextState) -> Tuple[float, ContextState]:
|
||||
@ -185,49 +208,161 @@ class ContextGraph:
|
||||
to root. The next state is always root.
|
||||
"""
|
||||
# The score of the fail arc
|
||||
score = self.root.total_score - state.total_score
|
||||
if state.is_end:
|
||||
score = 0
|
||||
score = self.root.node_score - state.local_node_score
|
||||
return (score, self.root)
|
||||
|
||||
def draw(
|
||||
self,
|
||||
title: Optional[str] = None,
|
||||
filename: Optional[str] = "",
|
||||
symbol_table: Optional[Dict[int, str]] = None,
|
||||
) -> "Digraph": # noqa
|
||||
|
||||
"""Visualize a ContextGraph via graphviz.
|
||||
|
||||
Render ContextGraph as an image via graphviz, and return the Digraph object;
|
||||
and optionally save to file `filename`.
|
||||
`filename` must have a suffix that graphviz understands, such as
|
||||
`pdf`, `svg` or `png`.
|
||||
|
||||
Note:
|
||||
You need to install graphviz to use this function::
|
||||
|
||||
pip install graphviz
|
||||
|
||||
Args:
|
||||
title:
|
||||
Title to be displayed in image, e.g. 'A simple FSA example'
|
||||
filename:
|
||||
Filename to (optionally) save to, e.g. 'foo.png', 'foo.svg',
|
||||
'foo.png' (must have a suffix that graphviz understands).
|
||||
symbol_table:
|
||||
Map the token ids to symbols.
|
||||
Returns:
|
||||
A Diagraph from grahpviz.
|
||||
"""
|
||||
|
||||
try:
|
||||
import graphviz
|
||||
except Exception:
|
||||
print("You cannot use `to_dot` unless the graphviz package is installed.")
|
||||
raise
|
||||
|
||||
graph_attr = {
|
||||
"rankdir": "LR",
|
||||
"size": "8.5,11",
|
||||
"center": "1",
|
||||
"orientation": "Portrait",
|
||||
"ranksep": "0.4",
|
||||
"nodesep": "0.25",
|
||||
}
|
||||
if title is not None:
|
||||
graph_attr["label"] = title
|
||||
|
||||
default_node_attr = {
|
||||
"shape": "circle",
|
||||
"style": "bold",
|
||||
"fontsize": "14",
|
||||
}
|
||||
|
||||
final_state_attr = {
|
||||
"shape": "doublecircle",
|
||||
"style": "bold",
|
||||
"fontsize": "14",
|
||||
}
|
||||
|
||||
final_state = -1
|
||||
dot = graphviz.Digraph(name="Context Graph", graph_attr=graph_attr)
|
||||
|
||||
seen = set()
|
||||
queue = deque()
|
||||
queue.append(self.root)
|
||||
# root id is always 0
|
||||
dot.node("0", label="0", **default_node_attr)
|
||||
dot.edge("0", "0", label=f"*/0")
|
||||
seen.add(0)
|
||||
|
||||
while len(queue):
|
||||
current_node = queue.popleft()
|
||||
for token, node in current_node.next.items():
|
||||
if node.id not in seen:
|
||||
node_score = f"{node.node_score:.2f}".rstrip("0").rstrip(".")
|
||||
local_node_score = f"{node.local_node_score:.2f}".rstrip(
|
||||
"0"
|
||||
).rstrip(".")
|
||||
label = f"{node.id}/({node_score},{local_node_score})"
|
||||
if node.is_end:
|
||||
dot.node(str(node.id), label=label, **final_state_attr)
|
||||
else:
|
||||
dot.node(str(node.id), label=label, **default_node_attr)
|
||||
seen.add(node.id)
|
||||
weight = f"{node.token_score:.2f}".rstrip("0").rstrip(".")
|
||||
label = str(token) if symbol_table is None else symbol_table[token]
|
||||
dot.edge(str(current_node.id), str(node.id), label=f"{label}/{weight}")
|
||||
dot.edge(
|
||||
str(node.id),
|
||||
str(node.fail.id),
|
||||
color="red",
|
||||
)
|
||||
queue.append(node)
|
||||
|
||||
if filename:
|
||||
_, extension = os.path.splitext(filename)
|
||||
if extension == "" or extension[0] != ".":
|
||||
raise ValueError(
|
||||
"Filename needs to have a suffix like .png, .pdf, .svg: {}".format(
|
||||
filename
|
||||
)
|
||||
)
|
||||
|
||||
import tempfile
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
temp_fn = dot.render(
|
||||
filename="temp",
|
||||
directory=tmp_dir,
|
||||
format=extension[1:],
|
||||
cleanup=True,
|
||||
)
|
||||
|
||||
shutil.move(temp_fn, filename)
|
||||
|
||||
return dot
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
contexts_str = ["HE", "SHE", "HIS", "HERS"]
|
||||
contexts_str = ["HE", "SHE", "SHELL", "HIS", "HERS", "HELLO"]
|
||||
contexts = []
|
||||
for s in contexts_str:
|
||||
contexts.append([ord(x) for x in s])
|
||||
|
||||
context_graph = ContextGraph(context_score=2)
|
||||
context_graph.build_context_graph(contexts)
|
||||
context_graph = ContextGraph(context_score=1)
|
||||
context_graph.build(contexts)
|
||||
|
||||
score, state = context_graph.forward_one_step(context_graph.root, ord("H"))
|
||||
assert score == 2, score
|
||||
assert state.token == ord("H"), state.token
|
||||
symbol_table = {}
|
||||
for contexts in contexts_str:
|
||||
for s in contexts:
|
||||
symbol_table[ord(s)] = s
|
||||
|
||||
score, state = context_graph.forward_one_step(state, ord("I"))
|
||||
assert score == 2, score
|
||||
assert state.token == ord("I"), state.token
|
||||
context_graph.draw(
|
||||
title="Graph for: " + " / ".join(contexts_str),
|
||||
filename="context_graph.pdf",
|
||||
symbol_table=symbol_table,
|
||||
)
|
||||
|
||||
score, state = context_graph.forward_one_step(state, ord("S"))
|
||||
assert score == 2, score
|
||||
assert state.token == -1, state.token
|
||||
|
||||
score, state = context_graph.finalize(state)
|
||||
assert score == 0, score
|
||||
assert state.token == -1, state.token
|
||||
|
||||
score, state = context_graph.forward_one_step(context_graph.root, ord("S"))
|
||||
assert score == 2, score
|
||||
assert state.token == ord("S"), state.token
|
||||
|
||||
score, state = context_graph.forward_one_step(state, ord("H"))
|
||||
assert score == 2, score
|
||||
assert state.token == ord("H"), state.token
|
||||
|
||||
score, state = context_graph.forward_one_step(state, ord("D"))
|
||||
assert score == -4, score
|
||||
assert state.token == -1, state.token
|
||||
|
||||
score, state = context_graph.forward_one_step(context_graph.root, ord("D"))
|
||||
assert score == 0, score
|
||||
assert state.token == -1, state.token
|
||||
queries = ["HERSHE", "HISHE", "SHED", "HELL", "HELLO", "DHRHISQ"]
|
||||
expected_scores = [7, 6, 3, 2, 5, 3]
|
||||
for i, query in enumerate(queries):
|
||||
total_scores = 0
|
||||
state = context_graph.root
|
||||
for q in query:
|
||||
score, state = context_graph.forward_one_step(state, ord(q))
|
||||
total_scores += score
|
||||
score, state = context_graph.finalize(state)
|
||||
assert state.token == -1, state.token
|
||||
total_scores += score
|
||||
assert total_scores == expected_scores[i], (
|
||||
total_scores,
|
||||
expected_scores[i],
|
||||
query,
|
||||
)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user