diff --git a/icefall/context_graph.py b/icefall/context_graph.py
index 264054dd2..1936f43b5 100644
--- a/icefall/context_graph.py
+++ b/icefall/context_graph.py
@@ -20,7 +20,13 @@ from typing import Dict, List, Tuple
 class ContextState:
     """The state in ContextGraph"""
 
-    def __init__(self, token: int, score: float, total_score: float, is_end: bool):
+    def __init__(
+        self,
+        token: int,
+        score: float,
+        total_score: float,
+        is_end: bool,
+    ):
         """Create a ContextState.
 
         Args:
@@ -81,11 +87,18 @@ class ContextGraph:
             queue.append(node)
         while queue:
             current_node = queue.pop(0)
-            current_fail = current_node.fail
             for token, node in current_node.next.items():
-                fail = current_fail
-                if token in current_fail.next:
-                    fail = current_fail.next[token]
+                fail = current_node.fail
+                if token in fail.next:
+                    fail = fail.next[token]
+                else:
+                    fail = fail.fail
+                    while token not in fail.next:
+                        fail = fail.fail
+                        if fail.token == -1:  # root
+                            break
+                    if token in fail.next:
+                        fail = fail.next[token]
                 node.fail = fail
                 queue.append(node)
 
@@ -116,7 +129,7 @@ class ContextGraph:
                         is_end=i == len(tokens) - 1,
                     )
                 node = node.next[token]
-            self._fill_fail()
+        self._fill_fail()
 
     def forward_one_step(
         self, state: ContextState, token: int
@@ -136,8 +149,6 @@ class ContextGraph:
         if token in state.next:
             node = state.next[token]
             score = node.score
-            # if the matched node is the end of a word/phrase, we will start
-            # from the root for next token.
             if node.is_end:
                 node = self.root
             return (score, node)