mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
fix some bugs
This commit is contained in:
parent
2e7e7875f5
commit
62557a1564
@ -20,7 +20,13 @@ from typing import Dict, List, Tuple
|
||||
class ContextState:
|
||||
"""The state in ContextGraph"""
|
||||
|
||||
def __init__(self, token: int, score: float, total_score: float, is_end: bool):
|
||||
def __init__(
|
||||
self,
|
||||
token: int,
|
||||
score: float,
|
||||
total_score: float,
|
||||
is_end: bool,
|
||||
):
|
||||
"""Create a ContextState.
|
||||
|
||||
Args:
|
||||
@ -81,11 +87,18 @@ class ContextGraph:
|
||||
queue.append(node)
|
||||
while queue:
|
||||
current_node = queue.pop(0)
|
||||
current_fail = current_node.fail
|
||||
for token, node in current_node.next.items():
|
||||
fail = current_fail
|
||||
if token in current_fail.next:
|
||||
fail = current_fail.next[token]
|
||||
fail = current_node.fail
|
||||
if token in fail.next:
|
||||
fail = fail.next[token]
|
||||
else:
|
||||
fail = fail.fail
|
||||
while token not in fail.next:
|
||||
fail = fail.fail
|
||||
if fail.token == -1: # root
|
||||
break
|
||||
if token in fail.next:
|
||||
fail = fail.next[token]
|
||||
node.fail = fail
|
||||
queue.append(node)
|
||||
|
||||
@ -116,7 +129,7 @@ class ContextGraph:
|
||||
is_end=i == len(tokens) - 1,
|
||||
)
|
||||
node = node.next[token]
|
||||
self._fill_fail()
|
||||
self._fill_fail()
|
||||
|
||||
def forward_one_step(
|
||||
self, state: ContextState, token: int
|
||||
@ -136,8 +149,6 @@ class ContextGraph:
|
||||
if token in state.next:
|
||||
node = state.next[token]
|
||||
score = node.score
|
||||
# if the matched node is the end of a word/phrase, we will start
|
||||
# from the root for next token.
|
||||
if node.is_end:
|
||||
node = self.root
|
||||
return (score, node)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user