mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
fix context graph
This commit is contained in:
parent
e90563cdff
commit
306380f839
@ -29,7 +29,7 @@ class ContextState:
|
|||||||
token: int,
|
token: int,
|
||||||
token_score: float,
|
token_score: float,
|
||||||
node_score: float,
|
node_score: float,
|
||||||
local_node_score: float,
|
output_score: float,
|
||||||
is_end: bool,
|
is_end: bool,
|
||||||
):
|
):
|
||||||
"""Create a ContextState.
|
"""Create a ContextState.
|
||||||
@ -40,16 +40,15 @@ class ContextState:
|
|||||||
The id of the root node is always 0.
|
The id of the root node is always 0.
|
||||||
token:
|
token:
|
||||||
The token id.
|
The token id.
|
||||||
score:
|
token_score:
|
||||||
The bonus for each token during decoding, which will hopefully
|
The bonus for each token during decoding, which will hopefully
|
||||||
boost the token up to survive beam search.
|
boost the token up to survive beam search.
|
||||||
node_score:
|
node_score:
|
||||||
The accumulated bonus from root of graph to current node, it will be
|
The accumulated bonus from root of graph to current node, it will be
|
||||||
used to calculate the score for fail arc.
|
used to calculate the score for fail arc.
|
||||||
local_node_score:
|
output_score:
|
||||||
The accumulated bonus from last ``end_node``(node with is_end true)
|
The total score of matched phrases, i.e. the sum of the node_score of all
|
||||||
to current_node, it will be used to calculate the score for fail arc.
|
the output nodes for the current node.
|
||||||
Node: The local_node_score of a ``end_node`` is 0.
|
|
||||||
is_end:
|
is_end:
|
||||||
True if current token is the end of a context.
|
True if current token is the end of a context.
|
||||||
"""
|
"""
|
||||||
@ -57,7 +56,7 @@ class ContextState:
|
|||||||
self.token = token
|
self.token = token
|
||||||
self.token_score = token_score
|
self.token_score = token_score
|
||||||
self.node_score = node_score
|
self.node_score = node_score
|
||||||
self.local_node_score = local_node_score
|
self.output_score = output_score
|
||||||
self.is_end = is_end
|
self.is_end = is_end
|
||||||
self.next = {}
|
self.next = {}
|
||||||
self.fail = None
|
self.fail = None
|
||||||
@ -93,7 +92,7 @@ class ContextGraph:
|
|||||||
token=-1,
|
token=-1,
|
||||||
token_score=0,
|
token_score=0,
|
||||||
node_score=0,
|
node_score=0,
|
||||||
local_node_score=0,
|
output_score=0,
|
||||||
is_end=False,
|
is_end=False,
|
||||||
)
|
)
|
||||||
self.root.fail = self.root
|
self.root.fail = self.root
|
||||||
@ -131,6 +130,7 @@ class ContextGraph:
|
|||||||
output = None
|
output = None
|
||||||
break
|
break
|
||||||
node.output = output
|
node.output = output
|
||||||
|
node.output_score += 0 if output is None else output.output_score
|
||||||
queue.append(node)
|
queue.append(node)
|
||||||
|
|
||||||
def build(self, token_ids: List[List[int]]):
|
def build(self, token_ids: List[List[int]]):
|
||||||
@ -153,14 +153,13 @@ class ContextGraph:
|
|||||||
if token not in node.next:
|
if token not in node.next:
|
||||||
self.num_nodes += 1
|
self.num_nodes += 1
|
||||||
is_end = i == len(tokens) - 1
|
is_end = i == len(tokens) - 1
|
||||||
|
node_score = node.node_score + self.context_score
|
||||||
node.next[token] = ContextState(
|
node.next[token] = ContextState(
|
||||||
id=self.num_nodes,
|
id=self.num_nodes,
|
||||||
token=token,
|
token=token,
|
||||||
token_score=self.context_score,
|
token_score=self.context_score,
|
||||||
node_score=node.node_score + self.context_score,
|
node_score=node_score,
|
||||||
local_node_score=0
|
output_score=node_score if is_end else 0,
|
||||||
if is_end
|
|
||||||
else (node.local_node_score + self.context_score),
|
|
||||||
is_end=is_end,
|
is_end=is_end,
|
||||||
)
|
)
|
||||||
node = node.next[token]
|
node = node.next[token]
|
||||||
@ -186,8 +185,6 @@ class ContextGraph:
|
|||||||
if token in state.next:
|
if token in state.next:
|
||||||
node = state.next[token]
|
node = state.next[token]
|
||||||
score = node.token_score
|
score = node.token_score
|
||||||
if state.is_end:
|
|
||||||
score += state.node_score
|
|
||||||
else:
|
else:
|
||||||
# token not matched
|
# token not matched
|
||||||
# We will trace along the fail arc until it matches the token or reaches
|
# We will trace along the fail arc until it matches the token or reaches
|
||||||
@ -202,14 +199,9 @@ class ContextGraph:
|
|||||||
node = node.next[token]
|
node = node.next[token]
|
||||||
|
|
||||||
# The score of the fail path
|
# The score of the fail path
|
||||||
score = node.node_score - state.local_node_score
|
score = node.node_score - state.node_score
|
||||||
assert node is not None
|
assert node is not None
|
||||||
matched_score = 0
|
return (score + node.output_score, node)
|
||||||
output = node.output
|
|
||||||
while output is not None:
|
|
||||||
matched_score += output.node_score
|
|
||||||
output = output.output
|
|
||||||
return (score + matched_score, node)
|
|
||||||
|
|
||||||
def finalize(self, state: ContextState) -> Tuple[float, ContextState]:
|
def finalize(self, state: ContextState) -> Tuple[float, ContextState]:
|
||||||
"""When reaching the end of the decoded sequence, we need to finalize
|
"""When reaching the end of the decoded sequence, we need to finalize
|
||||||
@ -227,8 +219,6 @@ class ContextGraph:
|
|||||||
"""
|
"""
|
||||||
# The score of the fail arc
|
# The score of the fail arc
|
||||||
score = -state.node_score
|
score = -state.node_score
|
||||||
if state.is_end:
|
|
||||||
score = 0
|
|
||||||
return (score, self.root)
|
return (score, self.root)
|
||||||
|
|
||||||
def draw(
|
def draw(
|
||||||
@ -307,10 +297,8 @@ class ContextGraph:
|
|||||||
for token, node in current_node.next.items():
|
for token, node in current_node.next.items():
|
||||||
if node.id not in seen:
|
if node.id not in seen:
|
||||||
node_score = f"{node.node_score:.2f}".rstrip("0").rstrip(".")
|
node_score = f"{node.node_score:.2f}".rstrip("0").rstrip(".")
|
||||||
local_node_score = f"{node.local_node_score:.2f}".rstrip(
|
output_score = f"{node.output_score:.2f}".rstrip("0").rstrip(".")
|
||||||
"0"
|
label = f"{node.id}/({node_score}, {output_score})"
|
||||||
).rstrip(".")
|
|
||||||
label = f"{node.id}/({node_score},{local_node_score})"
|
|
||||||
if node.is_end:
|
if node.is_end:
|
||||||
dot.node(str(node.id), label=label, **final_state_attr)
|
dot.node(str(node.id), label=label, **final_state_attr)
|
||||||
else:
|
else:
|
||||||
@ -391,6 +379,7 @@ if __name__ == "__main__":
|
|||||||
"HERSHE": 12, # "HE", "HERS", "S", "SHE", "HE"
|
"HERSHE": 12, # "HE", "HERS", "S", "SHE", "HE"
|
||||||
"HISHE": 9, # "HIS", "S", "SHE", "HE"
|
"HISHE": 9, # "HIS", "S", "SHE", "HE"
|
||||||
"SHED": 6, # "S", "SHE", "HE"
|
"SHED": 6, # "S", "SHE", "HE"
|
||||||
|
"SHELF": 6, # "S", "SHE", "HE"
|
||||||
"HELL": 2, # "HE"
|
"HELL": 2, # "HE"
|
||||||
"HELLO": 7, # "HE", "HELLO"
|
"HELLO": 7, # "HE", "HELLO"
|
||||||
"DHRHISQ": 4, # "HIS", "S"
|
"DHRHISQ": 4, # "HIS", "S"
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user