From 2e7e7875f5c3de70b519cbd0aef511fb483d1855 Mon Sep 17 00:00:00 2001 From: pkufool Date: Mon, 8 May 2023 12:22:20 +0800 Subject: [PATCH] Implement Aho-Corasick context graph --- .../beam_search.py | 17 +- .../pruned_transducer_stateless4/decode.py | 24 +- .../pruned_transducer_stateless5/decode.py | 32 +- icefall/context_graph.py | 408 ++++++++---------- 4 files changed, 223 insertions(+), 258 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py index d0ee60c29..33606768a 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py @@ -768,9 +768,6 @@ class Hypothesis: """Return a string representation of self.ys""" return "_".join(map(str, self.ys)) - def __str__(self) -> str: - return f"ys: {'_'.join([str(i) for i in self.ys])}, log_prob: {float(self.log_prob):.2f}, state: {self.context_state}" - class HypothesisList(object): def __init__(self, data: Optional[Dict[str, Hypothesis]] = None) -> None: @@ -919,7 +916,6 @@ def modified_beam_search( encoder_out: torch.Tensor, encoder_out_lens: torch.Tensor, context_graph: Optional[ContextGraph] = None, - num_context_history: int = 1, beam: int = 4, temperature: float = 1.0, return_timestamps: bool = False, @@ -971,7 +967,7 @@ def modified_beam_search( Hypothesis( ys=[blank_id] * context_size, log_prob=torch.zeros(1, dtype=torch.float32, device=device), - context_state=None if context_graph is None else ContextState(graph=context_graph, max_states=num_context_history), + context_state=None if context_graph is None else context_graph.root, timestamp=[], ) ) @@ -1056,12 +1052,15 @@ def modified_beam_search( new_token = topk_token_indexes[k] new_timestamp = hyp.timestamp[:] context_score = 0 - new_context_state = None if context_graph is None else hyp.context_state.clone() + new_context_state = None if context_graph is None else hyp.context_state if new_token not in (blank_id, unk_id): new_ys.append(new_token) new_timestamp.append(t) if context_graph is not None: - context_score, new_context_state = hyp.context_state.forward_one_step(new_token) + ( + context_score, + new_context_state, + ) = context_graph.forward_one_step(hyp.context_state, new_token) new_log_prob = topk_log_probs[k] + context_score @@ -1081,7 +1080,9 @@ def modified_beam_search( finalized_B = [HypothesisList() for _ in range(len(B))] for i, hyps in enumerate(B): for hyp in list(hyps): - context_score, new_context_state = hyp.context_state.finalize() + context_score, new_context_state = context_graph.finalize( + hyp.context_state + ) finalized_B[i].add( Hypothesis( ys=hyp.ys, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py index 9d658de6c..6a8e82fff 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py @@ -362,21 +362,20 @@ def get_parser(): "--context-score", type=float, default=2, - help="", - ) - - parser.add_argument( - "--num-context-history", - type=int, - default=1, - help="", + help=""" + The bonus score of each token for the context biasing words/phrases. + Used only when --decoding_method is modified_beam_search. 
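+        A word/phrase of N tokens accumulates a bonus of N * context_score
+        when it is fully matched; the bonus is taken back if the match does
+        not complete.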
+        """,
     )

     parser.add_argument(
         "--context-file",
         type=str,
         default="",
-        help="",
+        help="""
+        The path of the context biasing lists, one word/phrase per line.
+        Used only when --decoding_method is modified_beam_search.
+        """,
     )

     add_model_arguments(parser)

@@ -522,7 +521,6 @@ def decode_one_batch(
             encoder_out_lens=encoder_out_lens,
             beam=params.beam_size,
             context_graph=context_graph,
-            num_context_history=params.num_context_history,
             return_timestamps=True,
         )
     else:
@@ -579,7 +577,6 @@ def decode_one_batch(
     else:
         key = f"beam_size_{params.beam_size}"
         key += f"-context-score-{params.context_score}"
-        key += f"-num-context-history-{params.num_context_history}"

     return {key: (hyps, timestamps)}

@@ -785,7 +782,6 @@ def main():
     elif "beam_search" in params.decoding_method:
         params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
         params.suffix += f"-context-score-{params.context_score}"
-        params.suffix += f"-num-context-history-{params.num_context_history}"
     else:
         params.suffix += f"-context-{params.context_size}"
         params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
@@ -923,7 +919,7 @@ def main():
             for line in open(params.context_file).readlines():
                 contexts.append(line.strip())
             context_graph = ContextGraph(params.context_score)
-            context_graph.build_context_graph_bpe(contexts, sp)
+            context_graph.build_context_graph(sp.encode(contexts))
         else:
             context_graph = None
     else:
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode.py
index eb9962e3c..799992ae6 100755
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode.py
@@ -136,6 +136,7 @@ from beam_search import (
 from train import add_model_arguments, get_params, get_transducer_model

 from icefall import ContextGraph, LmScorer, NgramLm
+from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
 from icefall.checkpoint import (
     average_checkpoints,
     average_checkpoints_with_averaged_model,
@@ -313,21 +314,20 @@ def get_parser():
         "--context-score",
         type=float,
         default=2,
-        help="",
-    )
-
-    parser.add_argument(
-        "--num-context-history",
-        type=int,
-        default=1,
-        help="",
+        help="""
+        The bonus score of each token for the context biasing words/phrases.
+        Used only when --decoding_method is modified_beam_search.
+        """,
     )

     parser.add_argument(
         "--context-file",
         type=str,
         default="",
-        help="",
+        help="""
+        The path of the context biasing lists, one word/phrase per line.
+        Used only when --decoding_method is modified_beam_search.
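+        For example, a context file may look like:
+            你好中国
+            北京欢迎您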
+        """,
     )

     parser.add_argument(
@@ -472,7 +472,6 @@ def decode_one_batch(
             beam=params.beam_size,
             encoder_out_lens=encoder_out_lens,
             context_graph=context_graph,
-            num_context_history=params.num_context_history,
         )
         for i in range(encoder_out.size(0)):
             hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
@@ -535,7 +534,7 @@ def decode_one_batch(
         }
     else:
         return {
-            f"beam_size_{params.beam_size}_context_score_{params.context_score}_num_context_history_{params.num_context_history}": hyps
+            f"beam_size_{params.beam_size}_context_score_{params.context_score}": hyps
         }

@@ -685,7 +684,6 @@ def main():
     elif "beam_search" in params.decoding_method:
         params.suffix += f"-beam-{params.beam_size}"
         params.suffix += f"-context-score-{params.context_score}"
-        params.suffix += f"-num-context-history-{params.num_context_history}"
     else:
         params.suffix += f"-context-{params.context_size}"
         params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
@@ -715,11 +713,15 @@ def main():

     logging.info(f"Device: {device}")

-    # import pdb; pdb.set_trace()
     lexicon = Lexicon(params.lang_dir)
     params.blank_id = lexicon.token_table["<blk>"]
     params.vocab_size = max(lexicon.tokens) + 1

+    graph_compiler = CharCtcTrainingGraphCompiler(
+        lexicon=lexicon,
+        device=device,
+    )
+
     if params.simulate_streaming:
         assert (
             params.causal_convolution
@@ -851,9 +853,9 @@ def main():
         if os.path.exists(params.context_file):
             contexts = []
             for line in open(params.context_file).readlines():
-                contexts.append(line.strip())
+                # texts_to_ids() expects a list of strings and returns a
+                # list of token-id lists, hence the [...] and the [0].
+                contexts.append(graph_compiler.texts_to_ids([line.strip()])[0])
             context_graph = ContextGraph(params.context_score)
-            context_graph.build_context_graph_char(contexts, lexicon.token_table)
+            context_graph.build_context_graph(contexts)
         else:
             context_graph = None
diff --git a/icefall/context_graph.py b/icefall/context_graph.py
index 6cbcd043a..264054dd2 100644
--- a/icefall/context_graph.py
+++ b/icefall/context_graph.py
@@ -14,243 +14,209 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from heapq import heappush, heappop
-import re
-from dataclasses import dataclass
-from typing import List, Tuple
-import argparse
-import k2
-import kaldifst
-import sentencepiece as spm
-
-from icefall.utils import is_module_available
-
-
-class ContextGraph:
-    def __init__(self, context_score: float = 1):
-        self.context_score = context_score
-
-    def build_context_graph_char(
-        self, contexts: List[str], token_table: k2.SymbolTable
-    ):
-        """Convert a list of texts to a list-of-list of token IDs.
-
-        Args:
-          contexts:
-            It is a list of strings.
-            An example containing two strings is given below:
-
-                ['你好中国', '北京欢迎您']
-          token_table:
-            The SymbolTable containing tokens and corresponding ids.
-
-        Returns:
-          Return a list-of-list of token IDs.
- """ - ids: List[List[int]] = [] - whitespace = re.compile(r"([ \t])") - for text in contexts: - text = re.sub(whitespace, "", text) - sub_ids: List[int] = [] - skip = False - for txt in text: - if txt not in token_table: - skip = True - break - sub_ids.append(token_table[txt]) - if skip: - logging.warning(f"Skipping context {text}, as it has OOV char.") - continue - ids.append(sub_ids) - self.build_context_graph(ids) - - def build_context_graph_bpe( - self, contexts: List[str], sp: spm.SentencePieceProcessor - ): - contexts_bpe = sp.encode(contexts) - self.build_context_graph(contexts_bpe) - - def build_context_graph(self, token_ids: List[List[int]]): - graph = kaldifst.StdVectorFst() - start_state = ( - graph.add_state() - ) # 1st state will be state 0 (returned by add_state) - assert start_state == 0, start_state - graph.start = 0 # set the start state to 0 - graph.set_final(start_state, weight=kaldifst.TropicalWeight.one) - - for tokens in token_ids: - prev_state = start_state - next_state = start_state - backoff_score = 0 - for i in range(len(tokens)): - score = self.context_score - next_state = graph.add_state() if i < len(tokens) - 1 else start_state - graph.add_arc( - state=prev_state, - arc=kaldifst.StdArc( - ilabel=tokens[i], - olabel=tokens[i], - weight=score, - nextstate=next_state, - ), - ) - if i > 0: - graph.add_arc( - state=prev_state, - arc=kaldifst.StdArc( - ilabel=0, - olabel=0, - weight=-backoff_score, - nextstate=start_state, - ), - ) - prev_state = next_state - backoff_score += score - self.graph = kaldifst.determinize(graph) - kaldifst.arcsort(self.graph) - - def is_final_state(self, state_id: int) -> bool: - return self.graph.final(state_id) == kaldifst.TropicalWeight.one - - - def get_next_state(self, state_id: int, label: int) -> Tuple[int, float, bool]: - arc_iter = kaldifst.ArcIterator(self.graph, state_id) - num_arcs = self.graph.num_arcs(state_id) - - # The LM is arc sorted by ilabel, so we use binary search below. - left = 0 - right = num_arcs - 1 - while left <= right: - mid = (left + right) // 2 - arc_iter.seek(mid) - arc = arc_iter.value - if arc.ilabel < label: - left = mid + 1 - elif arc.ilabel > label: - right = mid - 1 - else: - return (arc.nextstate, arc.weight.value, True) - - # Backoff to state 0 with the score on epsilon arc (ilabel == 0) - arc_iter.seek(0) - arc = arc_iter.value - if arc.ilabel == 0: - return (0, 0, False) - else: - return (0, arc.weight.value, False) +from typing import Dict, List, Tuple class ContextState: - def __init__(self, graph: ContextGraph, max_states: int): - self.graph = graph - self.max_states = max_states - # [(total score, (score, state_id))] - self.states: List[Tuple[float, Tuple[float, int]]] = [] + """The state in ContextGraph""" - def __str__(self): - return ";".join([str(state) for state in self.states]) + def __init__(self, token: int, score: float, total_score: float, is_end: bool): + """Create a ContextState. - def clone(self): - new_context_state = ContextState(graph=self.graph, max_states=self.max_states) - new_context_state.states = self.states[:] - return new_context_state + Args: + token: + The token id. + score: + The bonus for each token during decoding, which will hopefully + boost the token up to survive beam search. + total_score: + The accumulated bonus from root of graph to current node, it will be + used to calculate the score for fail arc. + is_end: + True if current token is the end of a context. 
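+
+        Note: two more attributes are filled in afterwards: ``next`` maps a
+        token id to the child ContextState, and ``fail`` is the Aho-Corasick
+        fail arc that is followed when a token does not match any child.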
+ """ + self.token = token + self.score = score + self.total_score = total_score + self.is_end = is_end + self.next = {} + self.fail = None - def finalize(self) -> float: - new_context_state = ContextState(graph=self.graph, max_states=self.max_states) - if len(self.states) == 0: - return 0, new_context_state - item = heappop(self.states) - return item[0], new_context_state - def forward_one_step(self, label: int) -> float: - states = self.states[:] - new_states = [] - # expand current label from state state - status = self.graph.get_next_state(0, label) - if status[2]: - heappush(new_states, (-status[1], (status[1], status[0]))) +class ContextGraph: + """The ContextGraph is modified from Aho-Corasick which is mainly + a Trie with a fail arc for each node. + See https://en.wikipedia.org/wiki/Aho%E2%80%93Corasick_algorithm for more details + of Aho-Corasick algorithm. + + A ContextGraph contains some words / phrases that we expect to boost their + scores during decoding. If the substring of a decoded sequence matches the word / phrase + in the ContextGraph, we will give the decoded sequence a bonus to make it survive + beam search. + """ + + def __init__(self, context_score: float): + """Initialize a ContextGraph with the given ``context_score``. + + A root node will be created (**NOTE:** the token of root is hardcoded to -1). + + Args: + context_score: + The bonus score for each token(note: NOT for each word/phrase, it means longer + word/phrase will have larger bonus score, they have to be matched though). + """ + self.context_score = context_score + self.root = ContextState(token=-1, score=0, total_score=0, is_end=False) + self.root.fail = self.root + + def _fill_fail(self): + """This function fills the fail arc for each trie node, it can be computed + in linear time by performing a breadth-first search starting from the root. + See https://en.wikipedia.org/wiki/Aho%E2%80%93Corasick_algorithm for the + details of the algorithm. + """ + queue = [] + for token, node in self.root.next.items(): + node.fail = self.root + queue.append(node) + while queue: + current_node = queue.pop(0) + current_fail = current_node.fail + for token, node in current_node.next.items(): + fail = current_fail + if token in current_fail.next: + fail = current_fail.next[token] + node.fail = fail + queue.append(node) + + def build_context_graph(self, token_ids: List[List[int]]): + """Build the ContextGraph from a list of token list. + It first build a trie from the given token lists, then fill the fail arc + for each trie node. + + See https://en.wikipedia.org/wiki/Trie for how to build a trie. + + Args: + token_ids: + The given token lists to build the ContextGraph, it is a list of token list, + each token list contains the token ids for a word/phrase. The token id + could be an id of a char (modeling with single Chinese char) or an id + of a BPE (modeling with BPEs). + """ + for tokens in token_ids: + node = self.root + for i, token in enumerate(tokens): + if token not in node.next: + node.next[token] = ContextState( + token=token, + score=self.context_score, + # The total score is the accumulated score from root to current node, + # it will be used to calculate the score of fail arc later. + total_score=node.total_score + self.context_score, + is_end=i == len(tokens) - 1, + ) + node = node.next[token] + self._fill_fail() + + def forward_one_step( + self, state: ContextState, token: int + ) -> Tuple[float, ContextState]: + """Search the graph with given state and token. 
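+        If the token matches an arc leaving the state, we follow it and
+        collect the bonus of the matched token; otherwise we follow the
+        fail arcs toward the root, and the (usually negative) score adjusts
+        the accumulated bonus to that of the longest suffix that still
+        matches a context word/phrase.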
+
+        Args:
+          state:
+            The given state (trie node) to start from.
+          token:
+            The given token.
+
+        Returns:
+          Return a tuple of (score, next state).
+        """
+        # The token matches an arc leaving the current state.
+        if token in state.next:
+            node = state.next[token]
+            score = node.score
+            # If the matched node is the end of a word/phrase, we start
+            # from the root again for the next token.
+            if node.is_end:
+                node = self.root
+            return (score, node)
+        else:
+            # The token does not match any arc; trace along the fail arcs
+            # until a node has an arc for this token, or we reach the root.
+            node = state.fail
+            while token not in node.next:
+                node = node.fail
+                if node.token == -1:  # root
+                    break
+
+            if token in node.next:
+                node = node.next[token]
+            # The score of the fail arc; it is negative when we fall back
+            # to a shorter (or empty) partial match.
+            score = node.total_score - state.total_score
+            if node.is_end:
+                node = self.root
+            return (score, node)
+
+    def finalize(self, state: ContextState) -> Tuple[float, ContextState]:
+        """Finalize the matching when we reach the end of the decoded
+        sequence. The purpose is to subtract the bonus that was added for a
+        partial match, i.e. for a state that is not the end of a word/phrase.
+
+        Args:
+          state:
+            The given state (trie node).
+        Returns:
+          Return a tuple of (score, next state). If the state is the end of
+          a word/phrase the score is zero, otherwise it is the score of an
+          implicit fail arc to the root. The next state is always the root.
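+
+        For example, if the graph contains only the phrase "HERS" with
+        context_score=2 and decoding stops right after consuming "HE", the
+        state's total_score is 4, so finalize returns (-4, root).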
+        """
+        # The score of the implicit fail arc to the root; it takes back
+        # the bonus accumulated on an unfinished match.
+        score = self.root.total_score - state.total_score
+        if state.is_end:
+            score = 0
+        return (score, self.root)


 if __name__ == "__main__":
-    parser = argparse.ArgumentParser()
-    parser.add_argument(
-        "--bpe_model",
-        type=str,
-        help="Path to bpe model",
-    )
-    args = parser.parse_args()
-
-    sp = spm.SentencePieceProcessor()
-    sp.load(args.bpe_model)
-
-    contexts = ["LOVE CHINA", "HELLO WORLD", "LOVE WORLD"]
-    context_graph = ContextGraph()
-    context_graph.build_context_graph_bpe(contexts, sp)
-
-    if not is_module_available("graphviz"):
-        raise ValueError("Please 'pip install graphviz' first.")
-    import graphviz
-
-    fst_dot = kaldifst.draw(context_graph.graph, acceptor=False, portrait=True)
-    fst_source = graphviz.Source(fst_dot)
-    fst_source.render(outfile="context_graph.svg")
+    contexts_str = ["HE", "SHE", "HIS", "HERS"]
+    contexts = []
+    for s in contexts_str:
+        contexts.append([ord(x) for x in s])
+
+    context_graph = ContextGraph(context_score=2)
+    context_graph.build_context_graph(contexts)
+
+    # "H" matches the first token of "HE"/"HIS"/"HERS": bonus of 2.
+    score, state = context_graph.forward_one_step(context_graph.root, ord("H"))
+    assert score == 2, score
+    assert state.token == ord("H"), state.token
+
+    score, state = context_graph.forward_one_step(state, ord("I"))
+    assert score == 2, score
+    assert state.token == ord("I"), state.token
+
+    # "HIS" is fully matched, so the search restarts from the root.
+    score, state = context_graph.forward_one_step(state, ord("S"))
+    assert score == 2, score
+    assert state.token == -1, state.token
+
+    # Finalizing at the root takes back nothing.
+    score, state = context_graph.finalize(state)
+    assert score == 0, score
+    assert state.token == -1, state.token
+
+    score, state = context_graph.forward_one_step(context_graph.root, ord("S"))
+    assert score == 2, score
+    assert state.token == ord("S"), state.token
+
+    score, state = context_graph.forward_one_step(state, ord("H"))
+    assert score == 2, score
+    assert state.token == ord("H"), state.token
+
+    # "D" matches nothing after "SH": the fail arc takes back the bonus
+    # collected for "SH" (2 tokens * context_score = 4).
+    score, state = context_graph.forward_one_step(state, ord("D"))
+    assert score == -4, score
+    assert state.token == -1, state.token
+
+    # A mismatch at the root costs nothing.
+    score, state = context_graph.forward_one_step(context_graph.root, ord("D"))
+    assert score == 0, score
+    assert state.token == -1, state.token
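
A minimal usage sketch of the new API, assuming the patch above is applied: it
walks a token stream through the graph, making the same calls per token that
modified_beam_search now makes per hypothesis. Plain character codes stand in
for the BPE/char token ids used by the recipes.

    from icefall.context_graph import ContextGraph

    # context_score=2 gives a bonus of 2 per matched token.
    graph = ContextGraph(context_score=2)
    graph.build_context_graph(
        [[ord(c) for c in w] for w in ["HE", "SHE", "HIS", "HERS"]]
    )

    # Accumulate the bonus over a token stream, exactly as the decoder adds
    # it to a hypothesis' log-prob.
    total = 0.0
    state = graph.root
    for ch in "SHERS":
        score, state = graph.forward_one_step(state, ord(ch))
        total += score

    # At the end of the sequence, take back the bonus of any unfinished
    # match (here the dangling "S").
    score, state = graph.finalize(state)
    total += score

    # Only "SHE" was fully matched: 3 tokens * context_score == 6.
    assert total == 6, total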