mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-08 00:24:19 +00:00
fix context graph
This commit is contained in:
parent
e90563cdff
commit
306380f839
@ -29,7 +29,7 @@ class ContextState:
|
||||
token: int,
|
||||
token_score: float,
|
||||
node_score: float,
|
||||
local_node_score: float,
|
||||
output_score: float,
|
||||
is_end: bool,
|
||||
):
|
||||
"""Create a ContextState.
|
||||
@ -40,16 +40,15 @@ class ContextState:
|
||||
The id of the root node is always 0.
|
||||
token:
|
||||
The token id.
|
||||
score:
|
||||
token_score:
|
||||
The bonus for each token during decoding, which will hopefully
|
||||
boost the token up to survive beam search.
|
||||
node_score:
|
||||
The accumulated bonus from root of graph to current node, it will be
|
||||
used to calculate the score for fail arc.
|
||||
local_node_score:
|
||||
The accumulated bonus from last ``end_node``(node with is_end true)
|
||||
to current_node, it will be used to calculate the score for fail arc.
|
||||
Node: The local_node_score of a ``end_node`` is 0.
|
||||
output_score:
|
||||
The total scores of matched phrases, sum of the node_score of all
|
||||
the output node for current node.
|
||||
is_end:
|
||||
True if current token is the end of a context.
|
||||
"""
|
||||
@ -57,7 +56,7 @@ class ContextState:
|
||||
self.token = token
|
||||
self.token_score = token_score
|
||||
self.node_score = node_score
|
||||
self.local_node_score = local_node_score
|
||||
self.output_score = output_score
|
||||
self.is_end = is_end
|
||||
self.next = {}
|
||||
self.fail = None
|
||||
@ -93,7 +92,7 @@ class ContextGraph:
|
||||
token=-1,
|
||||
token_score=0,
|
||||
node_score=0,
|
||||
local_node_score=0,
|
||||
output_score=0,
|
||||
is_end=False,
|
||||
)
|
||||
self.root.fail = self.root
|
||||
@ -131,6 +130,7 @@ class ContextGraph:
|
||||
output = None
|
||||
break
|
||||
node.output = output
|
||||
node.output_score += 0 if output is None else output.output_score
|
||||
queue.append(node)
|
||||
|
||||
def build(self, token_ids: List[List[int]]):
|
||||
@ -153,14 +153,13 @@ class ContextGraph:
|
||||
if token not in node.next:
|
||||
self.num_nodes += 1
|
||||
is_end = i == len(tokens) - 1
|
||||
node_score = node.node_score + self.context_score
|
||||
node.next[token] = ContextState(
|
||||
id=self.num_nodes,
|
||||
token=token,
|
||||
token_score=self.context_score,
|
||||
node_score=node.node_score + self.context_score,
|
||||
local_node_score=0
|
||||
if is_end
|
||||
else (node.local_node_score + self.context_score),
|
||||
node_score=node_score,
|
||||
output_score=node_score if is_end else 0,
|
||||
is_end=is_end,
|
||||
)
|
||||
node = node.next[token]
|
||||
@ -186,8 +185,6 @@ class ContextGraph:
|
||||
if token in state.next:
|
||||
node = state.next[token]
|
||||
score = node.token_score
|
||||
if state.is_end:
|
||||
score += state.node_score
|
||||
else:
|
||||
# token not matched
|
||||
# We will trace along the fail arc until it matches the token or reaching
|
||||
@ -202,14 +199,9 @@ class ContextGraph:
|
||||
node = node.next[token]
|
||||
|
||||
# The score of the fail path
|
||||
score = node.node_score - state.local_node_score
|
||||
score = node.node_score - state.node_score
|
||||
assert node is not None
|
||||
matched_score = 0
|
||||
output = node.output
|
||||
while output is not None:
|
||||
matched_score += output.node_score
|
||||
output = output.output
|
||||
return (score + matched_score, node)
|
||||
return (score + node.output_score, node)
|
||||
|
||||
def finalize(self, state: ContextState) -> Tuple[float, ContextState]:
|
||||
"""When reaching the end of the decoded sequence, we need to finalize
|
||||
@ -227,8 +219,6 @@ class ContextGraph:
|
||||
"""
|
||||
# The score of the fail arc
|
||||
score = -state.node_score
|
||||
if state.is_end:
|
||||
score = 0
|
||||
return (score, self.root)
|
||||
|
||||
def draw(
|
||||
@ -307,10 +297,8 @@ class ContextGraph:
|
||||
for token, node in current_node.next.items():
|
||||
if node.id not in seen:
|
||||
node_score = f"{node.node_score:.2f}".rstrip("0").rstrip(".")
|
||||
local_node_score = f"{node.local_node_score:.2f}".rstrip(
|
||||
"0"
|
||||
).rstrip(".")
|
||||
label = f"{node.id}/({node_score},{local_node_score})"
|
||||
output_score = f"{node.output_score:.2f}".rstrip("0").rstrip(".")
|
||||
label = f"{node.id}/({node_score}, {output_score})"
|
||||
if node.is_end:
|
||||
dot.node(str(node.id), label=label, **final_state_attr)
|
||||
else:
|
||||
@ -391,6 +379,7 @@ if __name__ == "__main__":
|
||||
"HERSHE": 12, # "HE", "HERS", "S", "SHE", "HE"
|
||||
"HISHE": 9, # "HIS", "S", "SHE", "HE"
|
||||
"SHED": 6, # "S", "SHE", "HE"
|
||||
"SHELF": 6, # "S", "SHE", "HE"
|
||||
"HELL": 2, # "HE"
|
||||
"HELLO": 7, # "HE", "HELLO"
|
||||
"DHRHISQ": 4, # "HIS", "S"
|
||||
|
Loading…
x
Reference in New Issue
Block a user