From ba257efbcddf990edeee615e4af541ae378da3e2 Mon Sep 17 00:00:00 2001 From: Wei Kang Date: Sat, 3 Jun 2023 21:28:49 +0800 Subject: [PATCH] Add Context biasing (#1038) * Add context biasing for librispeech * Add context biasing for wenetspeech * fix bugs * Implement Aho-Corasick context graph * fix some bugs * Fixes to forward_one_step; add draw to context graph * add output arc; fix black * Fix wenetspeech tokenizer * Minor fixes to the decode.py --- .../pruned_transducer_stateless7/decode.py | 74 +++- .../beam_search.py | 44 +- .../pruned_transducer_stateless4/decode.py | 57 ++- .../asr_datamodule.py | 4 +- .../pruned_transducer_stateless5/decode.py | 85 +++- icefall/__init__.py | 2 + icefall/context_graph.py | 412 ++++++++++++++++++ 7 files changed, 652 insertions(+), 26 deletions(-) create mode 100644 icefall/context_graph.py diff --git a/egs/aishell/ASR/pruned_transducer_stateless7/decode.py b/egs/aishell/ASR/pruned_transducer_stateless7/decode.py index af54af8da..be58c4e43 100755 --- a/egs/aishell/ASR/pruned_transducer_stateless7/decode.py +++ b/egs/aishell/ASR/pruned_transducer_stateless7/decode.py @@ -58,6 +58,7 @@ Usage: import argparse import logging +import os from collections import defaultdict from pathlib import Path from typing import Dict, List, Optional, Tuple @@ -76,6 +77,8 @@ from beam_search import ( ) from train import add_model_arguments, get_params, get_transducer_model +from icefall import ContextGraph, LmScorer, NgramLm +from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler from icefall.checkpoint import ( average_checkpoints, average_checkpoints_with_averaged_model, @@ -211,6 +214,26 @@ def get_parser(): Used only when --decoding_method is greedy_search""", ) + parser.add_argument( + "--context-score", + type=float, + default=2, + help=""" + The bonus score of each token for the context biasing words/phrases. + Used only when --decoding_method is modified_beam_search. + """, + ) + + parser.add_argument( + "--context-file", + type=str, + default="", + help=""" + The path of the context biasing lists, one word/phrase each line + Used only when --decoding_method is modified_beam_search. + """, + ) + add_model_arguments(parser) return parser @@ -222,6 +245,7 @@ def decode_one_batch( token_table: k2.SymbolTable, batch: dict, decoding_graph: Optional[k2.Fsa] = None, + context_graph: Optional[ContextGraph] = None, ) -> Dict[str, List[List[str]]]: """Decode one batch and return the result in a dict. The dict has the following format: @@ -285,6 +309,7 @@ def decode_one_batch( encoder_out=encoder_out, encoder_out_lens=encoder_out_lens, beam=params.beam_size, + context_graph=context_graph, ) else: hyp_tokens = [] @@ -324,7 +349,12 @@ def decode_one_batch( ): hyps } else: - return {f"beam_size_{params.beam_size}": hyps} + key = f"beam_size_{params.beam_size}" + if params.has_contexts: + key += f"-context-score-{params.context_score}" + else: + key += "-no-context-words" + return {key: hyps} def decode_dataset( @@ -333,6 +363,7 @@ def decode_dataset( model: nn.Module, token_table: k2.SymbolTable, decoding_graph: Optional[k2.Fsa] = None, + context_graph: Optional[ContextGraph] = None, ) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: """Decode dataset. 
@@ -377,6 +408,7 @@ def decode_dataset( model=model, token_table=token_table, decoding_graph=decoding_graph, + context_graph=context_graph, batch=batch, ) @@ -407,16 +439,17 @@ def save_results( for key, results in results_dict.items(): recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" results = sorted(results) - store_transcripts(filename=recog_path, texts=results) + # we compute CER for aishell dataset. + results_char = [] + for res in results: + results_char.append((res[0], list("".join(res[1])), list("".join(res[2])))) + + store_transcripts(filename=recog_path, texts=results_char) logging.info(f"The transcripts are stored in {recog_path}") # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" - # we compute CER for aishell dataset. - results_char = [] - for res in results: - results_char.append((res[0], list("".join(res[1])), list("".join(res[2])))) with open(errs_filename, "w") as f: wer = write_error_stats( f, f"{test_set_name}-{key}", results_char, enable_log=True @@ -457,6 +490,12 @@ def main(): "fast_beam_search", "modified_beam_search", ) + + if os.path.exists(params.context_file): + params.has_contexts = True + else: + params.has_contexts = False + params.res_dir = params.exp_dir / params.decoding_method if params.iter > 0: @@ -470,6 +509,10 @@ def main(): params.suffix += f"-max-states-{params.max_states}" elif "beam_search" in params.decoding_method: params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" + if params.has_contexts: + params.suffix += f"-context-score-{params.context_score}" + else: + params.suffix += "-no-contexts-words" else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -490,6 +533,11 @@ def main(): params.blank_id = 0 params.vocab_size = max(lexicon.tokens) + 1 + graph_compiler = CharCtcTrainingGraphCompiler( + lexicon=lexicon, + device=device, + ) + logging.info(params) logging.info("About to create model") @@ -586,6 +634,19 @@ def main(): else: decoding_graph = None + if params.decoding_method == "modified_beam_search": + if os.path.exists(params.context_file): + contexts_text = [] + for line in open(params.context_file).readlines(): + contexts_text.append(line.strip()) + contexts = graph_compiler.texts_to_ids(contexts_text) + context_graph = ContextGraph(params.context_score) + context_graph.build(contexts) + else: + context_graph = None + else: + context_graph = None + num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}") @@ -608,6 +669,7 @@ def main(): model=model, token_table=lexicon.token_table, decoding_graph=decoding_graph, + context_graph=context_graph, ) save_results( diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py index f5f15808d..25b79d600 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py @@ -24,7 +24,7 @@ import sentencepiece as spm import torch from model import Transducer -from icefall import NgramLm, NgramLmStateCost +from icefall import ContextGraph, ContextState, NgramLm, NgramLmStateCost from icefall.decode import Nbest, one_best_decoding from icefall.lm_wrapper import LmScorer from icefall.rnn_lm.model import RnnLmModel @@ -765,6 +765,9 @@ class Hypothesis: # 
N-gram LM state state_cost: Optional[NgramLmStateCost] = None + # Context graph state + context_state: Optional[ContextState] = None + @property def key(self) -> str: """Return a string representation of self.ys""" @@ -917,6 +920,7 @@ def modified_beam_search( model: Transducer, encoder_out: torch.Tensor, encoder_out_lens: torch.Tensor, + context_graph: Optional[ContextGraph] = None, beam: int = 4, temperature: float = 1.0, return_timestamps: bool = False, @@ -968,6 +972,7 @@ def modified_beam_search( Hypothesis( ys=[blank_id] * context_size, log_prob=torch.zeros(1, dtype=torch.float32, device=device), + context_state=None if context_graph is None else context_graph.root, timestamp=[], ) ) @@ -990,6 +995,7 @@ def modified_beam_search( hyps_shape = get_hyps_shape(B).to(device) A = [list(b) for b in B] + B = [HypothesisList() for _ in range(batch_size)] ys_log_probs = torch.cat( @@ -1047,21 +1053,51 @@ def modified_beam_search( for k in range(len(topk_hyp_indexes)): hyp_idx = topk_hyp_indexes[k] hyp = A[i][hyp_idx] - new_ys = hyp.ys[:] new_token = topk_token_indexes[k] new_timestamp = hyp.timestamp[:] + context_score = 0 + new_context_state = None if context_graph is None else hyp.context_state if new_token not in (blank_id, unk_id): new_ys.append(new_token) new_timestamp.append(t) + if context_graph is not None: + ( + context_score, + new_context_state, + ) = context_graph.forward_one_step(hyp.context_state, new_token) + + new_log_prob = topk_log_probs[k] + context_score - new_log_prob = topk_log_probs[k] new_hyp = Hypothesis( - ys=new_ys, log_prob=new_log_prob, timestamp=new_timestamp + ys=new_ys, + log_prob=new_log_prob, + timestamp=new_timestamp, + context_state=new_context_state, ) B[i].add(new_hyp) B = B + finalized_B + + # finalize context_state, if the matched contexts do not reach final state + # we need to add the score on the corresponding backoff arc + if context_graph is not None: + finalized_B = [HypothesisList() for _ in range(len(B))] + for i, hyps in enumerate(B): + for hyp in list(hyps): + context_score, new_context_state = context_graph.finalize( + hyp.context_state + ) + finalized_B[i].add( + Hypothesis( + ys=hyp.ys, + log_prob=hyp.log_prob + context_score, + timestamp=hyp.timestamp, + context_state=new_context_state, + ) + ) + B = finalized_B + best_hyps = [b.get_most_probable(length_norm=True) for b in B] sorted_ans = [h.ys[context_size:] for h in best_hyps] diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py index 79d919ab1..524366068 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py @@ -125,6 +125,7 @@ For example: import argparse import logging import math +import os from collections import defaultdict from pathlib import Path from typing import Dict, List, Optional, Tuple @@ -146,6 +147,7 @@ from beam_search import ( ) from train import add_model_arguments, get_params, get_transducer_model +from icefall import ContextGraph from icefall.checkpoint import ( average_checkpoints, average_checkpoints_with_averaged_model, @@ -353,6 +355,27 @@ def get_parser(): Used only when the decoding method is fast_beam_search_nbest, fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", ) + + parser.add_argument( + "--context-score", + type=float, + default=2, + help=""" + The bonus score of each token for the context biasing words/phrases. + Used only when --decoding_method is modified_beam_search. 
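+        For example, with the default score of 2, a fully matched phrase of
+        3 tokens receives a total bonus of 6 (2 per matched token).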
+ """, + ) + + parser.add_argument( + "--context-file", + type=str, + default="", + help=""" + The path of the context biasing lists, one word/phrase each line + Used only when --decoding_method is modified_beam_search. + """, + ) + add_model_arguments(parser) return parser @@ -365,6 +388,7 @@ def decode_one_batch( batch: dict, word_table: Optional[k2.SymbolTable] = None, decoding_graph: Optional[k2.Fsa] = None, + context_graph: Optional[ContextGraph] = None, ) -> Dict[str, Tuple[List[List[str]], List[List[float]]]]: """Decode one batch and return the result in a dict. The dict has the following format: @@ -494,6 +518,7 @@ def decode_one_batch( encoder_out=encoder_out, encoder_out_lens=encoder_out_lens, beam=params.beam_size, + context_graph=context_graph, return_timestamps=True, ) else: @@ -548,7 +573,12 @@ def decode_one_batch( return {key: (hyps, timestamps)} else: - return {f"beam_size_{params.beam_size}": (hyps, timestamps)} + key = f"beam_size_{params.beam_size}" + if params.has_contexts: + key += f"-context-score-{params.context_score}" + else: + key += "-no-context-words" + return {key: (hyps, timestamps)} def decode_dataset( @@ -558,6 +588,7 @@ def decode_dataset( sp: spm.SentencePieceProcessor, word_table: Optional[k2.SymbolTable] = None, decoding_graph: Optional[k2.Fsa] = None, + context_graph: Optional[ContextGraph] = None, ) -> Dict[str, List[Tuple[str, List[str], List[str], List[float], List[float]]]]: """Decode dataset. @@ -622,6 +653,7 @@ def decode_dataset( decoding_graph=decoding_graph, word_table=word_table, batch=batch, + context_graph=context_graph, ) for name, (hyps, timestamps_hyp) in hyps_dict.items(): @@ -728,6 +760,12 @@ def main(): "fast_beam_search_nbest_oracle", "modified_beam_search", ) + + if os.path.exists(params.context_file): + params.has_contexts = True + else: + params.has_contexts = False + params.res_dir = params.exp_dir / params.decoding_method if params.iter > 0: @@ -750,6 +788,10 @@ def main(): params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" elif "beam_search" in params.decoding_method: params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" + if params.has_contexts: + params.suffix += f"-context-score-{params.context_score}" + else: + params.suffix += "-no-context-words" else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -881,6 +923,18 @@ def main(): decoding_graph = None word_table = None + if params.decoding_method == "modified_beam_search": + if os.path.exists(params.context_file): + contexts = [] + for line in open(params.context_file).readlines(): + contexts.append(line.strip()) + context_graph = ContextGraph(params.context_score) + context_graph.build(sp.encode(contexts)) + else: + context_graph = None + else: + context_graph = None + num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}") @@ -905,6 +959,7 @@ def main(): sp=sp, word_table=word_table, decoding_graph=decoding_graph, + context_graph=context_graph, ) save_results( diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py index c9e30e737..7cb2e1048 100644 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py +++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py @@ -106,7 +106,7 @@ class WenetSpeechAsrDataModule: group.add_argument( "--num-buckets", type=int, - default=300, + default=30, 
help="The number of buckets for the DynamicBucketingSampler" "(you might want to increase it for larger datasets).", ) @@ -364,7 +364,7 @@ class WenetSpeechAsrDataModule: return valid_dl def test_dataloaders(self, cuts: CutSet) -> DataLoader: - logging.debug("About to create test dataset") + logging.info("About to create test dataset") test = K2SpeechRecognitionDataset( input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))) if self.args.on_the_fly_feats diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode.py index 46ba6b005..dc431578c 100755 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode.py +++ b/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode.py @@ -92,7 +92,7 @@ When training with the L subset, the streaming usage: --causal-convolution 1 \ --decode-chunk-size 16 \ --left-context 64 - + (4) modified beam search with RNNLM shallow fusion ./pruned_transducer_stateless5/decode.py \ --epoch 35 \ @@ -112,8 +112,10 @@ When training with the L subset, the streaming usage: import argparse +import glob import logging import math +import os from collections import defaultdict from pathlib import Path from typing import Dict, List, Optional, Tuple @@ -133,7 +135,8 @@ from beam_search import ( ) from train import add_model_arguments, get_params, get_transducer_model -from icefall import LmScorer, NgramLm +from icefall import ContextGraph, LmScorer, NgramLm +from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler from icefall.checkpoint import ( average_checkpoints, average_checkpoints_with_averaged_model, @@ -307,6 +310,26 @@ def get_parser(): help="left context can be seen during decoding (in frames after subsampling)", ) + parser.add_argument( + "--context-score", + type=float, + default=2, + help=""" + The bonus score of each token for the context biasing words/phrases. + Used only when --decoding_method is modified_beam_search. + """, + ) + + parser.add_argument( + "--context-file", + type=str, + default="", + help=""" + The path of the context biasing lists, one word/phrase each line + Used only when --decoding_method is modified_beam_search. 
+ """, + ) + parser.add_argument( "--use-shallow-fusion", type=str2bool, @@ -362,6 +385,7 @@ def decode_one_batch( lexicon: Lexicon, batch: dict, decoding_graph: Optional[k2.Fsa] = None, + context_graph: Optional[ContextGraph] = None, ngram_lm: Optional[NgramLm] = None, ngram_lm_scale: float = 1.0, LM: Optional[LmScorer] = None, @@ -402,14 +426,13 @@ def decode_one_batch( supervisions = batch["supervisions"] feature_lens = supervisions["num_frames"].to(device) - feature_lens += params.left_context - feature = torch.nn.functional.pad( - feature, - pad=(0, 0, 0, params.left_context), - value=LOG_EPS, - ) - if params.simulate_streaming: + feature_lens += params.left_context + feature = torch.nn.functional.pad( + feature, + pad=(0, 0, 0, params.left_context), + value=LOG_EPS, + ) encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward( x=feature, x_lens=feature_lens, @@ -448,6 +471,7 @@ def decode_one_batch( encoder_out=encoder_out, beam=params.beam_size, encoder_out_lens=encoder_out_lens, + context_graph=context_graph, ) for i in range(encoder_out.size(0)): hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) @@ -509,7 +533,12 @@ def decode_one_batch( ): hyps } else: - return {f"beam_size_{params.beam_size}": hyps} + key = f"beam_size_{params.beam_size}" + if params.has_contexts: + key += f"-context-score-{params.context_score}" + else: + key += "-no-context-words" + return {key: hyps} def decode_dataset( @@ -518,6 +547,7 @@ def decode_dataset( model: nn.Module, lexicon: Lexicon, decoding_graph: Optional[k2.Fsa] = None, + context_graph: Optional[ContextGraph] = None, ngram_lm: Optional[NgramLm] = None, ngram_lm_scale: float = 1.0, LM: Optional[LmScorer] = None, @@ -567,6 +597,7 @@ def decode_dataset( lexicon=lexicon, decoding_graph=decoding_graph, batch=batch, + context_graph=context_graph, ngram_lm=ngram_lm, ngram_lm_scale=ngram_lm_scale, LM=LM, @@ -646,6 +677,12 @@ def main(): "modified_beam_search_lm_shallow_fusion", "modified_beam_search_LODR", ) + + if os.path.exists(params.context_file): + params.has_contexts = True + else: + params.has_contexts = False + params.res_dir = params.exp_dir / params.decoding_method params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" @@ -655,6 +692,10 @@ def main(): params.suffix += f"-max-states-{params.max_states}" elif "beam_search" in params.decoding_method: params.suffix += f"-beam-{params.beam_size}" + if params.has_contexts: + params.suffix += f"-context-score-{params.context_score}" + else: + params.suffix += "-no-contexts-words" else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -684,11 +725,15 @@ def main(): logging.info(f"Device: {device}") - # import pdb; pdb.set_trace() lexicon = Lexicon(params.lang_dir) params.blank_id = lexicon.token_table[""] params.vocab_size = max(lexicon.tokens) + 1 + graph_compiler = CharCtcTrainingGraphCompiler( + lexicon=lexicon, + device=device, + ) + if params.simulate_streaming: assert ( params.causal_convolution @@ -816,6 +861,19 @@ def main(): else: decoding_graph = None + if params.decoding_method == "modified_beam_search": + if os.path.exists(params.context_file): + contexts_text = [] + for line in open(params.context_file).readlines(): + contexts_text.append(line.strip()) + contexts = graph_compiler.texts_to_ids(contexts_text) + context_graph = ContextGraph(params.context_score) + context_graph.build(contexts) + else: + context_graph = None + else: + context_graph = None + num_param = sum([p.numel() for p in 
model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")
 
@@ -833,15 +891,16 @@ def main():
     test_meeting_dl = wenetspeech.test_dataloaders(test_meeting_cuts)
 
     test_sets = ["DEV", "TEST_NET", "TEST_MEETING"]
-    test_dl = [dev_dl, test_net_dl, test_meeting_dl]
+    test_dls = [dev_dl, test_net_dl, test_meeting_dl]
 
-    for test_set, test_dl in zip(test_sets, test_dl):
+    for test_set, test_dl in zip(test_sets, test_dls):
         results_dict = decode_dataset(
             dl=test_dl,
             params=params,
             model=model,
             lexicon=lexicon,
             decoding_graph=decoding_graph,
+            context_graph=context_graph,
             ngram_lm=ngram_lm,
             ngram_lm_scale=ngram_lm_scale,
             LM=LM,
diff --git a/icefall/__init__.py b/icefall/__init__.py
index 5d846b41d..05e2b408c 100644
--- a/icefall/__init__.py
+++ b/icefall/__init__.py
@@ -23,6 +23,8 @@ from .checkpoint import (
     save_checkpoint_with_global_batch_idx,
 )
 
+from .context_graph import ContextGraph, ContextState
+
 from .decode import (
     get_lattice,
     nbest_decoding,
diff --git a/icefall/context_graph.py b/icefall/context_graph.py
new file mode 100644
index 000000000..c78de30f6
--- /dev/null
+++ b/icefall/context_graph.py
@@ -0,0 +1,412 @@
+# Copyright 2023 Xiaomi Corp. (authors: Wei Kang)
+#
+# See ../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import shutil
+from collections import deque
+from typing import Dict, List, Optional, Tuple
+
+
+class ContextState:
+    """The state in ContextGraph"""
+
+    def __init__(
+        self,
+        id: int,
+        token: int,
+        token_score: float,
+        node_score: float,
+        local_node_score: float,
+        is_end: bool,
+    ):
+        """Create a ContextState.
+
+        Args:
+          id:
+            The node id, currently only used for visualization. Ids lie in
+            [0, graph.num_nodes); the id of the root node is always 0.
+          token:
+            The token id.
+          token_score:
+            The bonus given to this token during decoding, which will hopefully
+            boost the token enough to survive beam search.
+          node_score:
+            The accumulated bonus from the root of the graph to the current
+            node; it is used to calculate the score of the fail arc.
+          local_node_score:
+            The accumulated bonus from the last ``end_node`` (a node with
+            ``is_end`` True) to the current node; it is used to calculate the
+            score of the fail arc.
+            Note: the ``local_node_score`` of an ``end_node`` is 0.
+          is_end:
+            True if the current token is the end of a context.
+        """
+        self.id = id
+        self.token = token
+        self.token_score = token_score
+        self.node_score = node_score
+        self.local_node_score = local_node_score
+        self.is_end = is_end
+        self.next = {}
+        self.fail = None
+        self.output = None
+
+
+class ContextGraph:
+    """The ContextGraph is modified from Aho-Corasick; it is essentially
+    a trie with a fail arc for each node.
+    See https://en.wikipedia.org/wiki/Aho%E2%80%93Corasick_algorithm for more
+    details of the Aho-Corasick algorithm.
+
+    A ContextGraph contains the words / phrases whose scores we expect to
+    boost during decoding.
+    If a substring of a decoded sequence matches a word / phrase in the
+    ContextGraph, we give the decoded sequence a bonus to help it survive
+    beam search.
+    """
+
+    def __init__(self, context_score: float):
+        """Initialize a ContextGraph with the given ``context_score``.
+
+        A root node will be created (**NOTE:** the token of the root is
+        hardcoded to -1).
+
+        Args:
+          context_score:
+            The bonus score for each token (note: NOT for each word/phrase;
+            a longer word/phrase therefore gets a larger total bonus, but it
+            has to be fully matched).
+        """
+        self.context_score = context_score
+        self.num_nodes = 0
+        self.root = ContextState(
+            id=self.num_nodes,
+            token=-1,
+            token_score=0,
+            node_score=0,
+            local_node_score=0,
+            is_end=False,
+        )
+        self.root.fail = self.root
+
+    def _fill_fail_output(self):
+        """Fill in the fail and output arcs for each trie node; they can be
+        computed in linear time by a breadth-first search starting from the
+        root. See https://en.wikipedia.org/wiki/Aho%E2%80%93Corasick_algorithm
+        for the details of the algorithm.
+        """
+        queue = deque()
+        for token, node in self.root.next.items():
+            # The fail arc of every depth-1 node points back to the root.
+            node.fail = self.root
+            queue.append(node)
+        while queue:
+            current_node = queue.popleft()
+            for token, node in current_node.next.items():
+                fail = current_node.fail
+                if token in fail.next:
+                    fail = fail.next[token]
+                else:
+                    fail = fail.fail
+                    while token not in fail.next:
+                        fail = fail.fail
+                        if fail.token == -1:  # root
+                            break
+                    if token in fail.next:
+                        fail = fail.next[token]
+                node.fail = fail
+                # fill the output arc
+                output = node.fail
+                while not output.is_end:
+                    output = output.fail
+                    if output.token == -1:  # root
+                        output = None
+                        break
+                node.output = output
+                queue.append(node)
+
+    def build(self, token_ids: List[List[int]]):
+        """Build the ContextGraph from a list of token lists.
+        It first builds a trie from the given token lists, then fills in the
+        fail and output arcs for each trie node.
+
+        See https://en.wikipedia.org/wiki/Trie for how to build a trie.
+
+        Args:
+          token_ids:
+            The token lists to build the ContextGraph from; each inner list
+            contains the token ids of one word/phrase. A token id could be
+            the id of a char (modeling with single Chinese chars) or the id
+            of a BPE piece (modeling with BPEs).
+        """
+        for tokens in token_ids:
+            node = self.root
+            for i, token in enumerate(tokens):
+                if token not in node.next:
+                    self.num_nodes += 1
+                    is_end = i == len(tokens) - 1
+                    node.next[token] = ContextState(
+                        id=self.num_nodes,
+                        token=token,
+                        token_score=self.context_score,
+                        node_score=node.node_score + self.context_score,
+                        local_node_score=0
+                        if is_end
+                        else (node.local_node_score + self.context_score),
+                        is_end=is_end,
+                    )
+                node = node.next[token]
+        self._fill_fail_output()
+
+    def forward_one_step(
+        self, state: ContextState, token: int
+    ) -> Tuple[float, ContextState]:
+        """Search the graph with the given state and token.
+
+        Args:
+          state:
+            The current graph state (a trie node) to search from.
+          token:
+            The token id emitted at this step.
+
+        Returns:
+          Return a tuple of (score, next state).
+        """
+        node = None
+        score = 0
+        # token matched
+        if token in state.next:
+            node = state.next[token]
+            score = node.token_score
+            if state.is_end:
+                score += state.node_score
+        else:
+            # token not matched
+            # Trace along the fail arcs until the token matches or we reach
+            # the root of the graph.
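+            # The fail-arc score below (the new node's node_score minus the
+            # old state's local_node_score) cancels the bonuses awarded along
+            # the abandoned partial match, so a match that dies out
+            # contributes no net bonus.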
+            node = state.fail
+            while token not in node.next:
+                node = node.fail
+                if node.token == -1:  # root
+                    break
+
+            if token in node.next:
+                node = node.next[token]
+
+            # The score of the fail path
+            score = node.node_score - state.local_node_score
+        assert node is not None
+
+        # Collect the bonuses of shorter contexts that end at this node,
+        # reachable via the output arcs.
+        matched_score = 0
+        output = node.output
+        while output is not None:
+            matched_score += output.node_score
+            output = output.output
+        return (score + matched_score, node)
+
+    def finalize(self, state: ContextState) -> Tuple[float, ContextState]:
+        """Finalize the matching when reaching the end of the decoded
+        sequence; the purpose is to subtract the bonus that was added for a
+        partial match that did not reach the end of a word/phrase.
+
+        Args:
+          state:
+            The given state (trie node).
+
+        Returns:
+          Return a tuple of (score, next state). If the state is the end of a
+          word/phrase the score is zero; otherwise it is the score of an
+          implicit fail arc to the root. The next state is always the root.
+        """
+        # The score of the fail arc
+        score = -state.node_score
+        if state.is_end:
+            score = 0
+        return (score, self.root)
+
+    def draw(
+        self,
+        title: Optional[str] = None,
+        filename: Optional[str] = "",
+        symbol_table: Optional[Dict[int, str]] = None,
+    ) -> "Digraph":  # noqa
+        """Visualize a ContextGraph via graphviz.
+
+        Render the ContextGraph as an image via graphviz, return the Digraph
+        object, and optionally save it to `filename`.
+        `filename` must have a suffix that graphviz understands, such as
+        `pdf`, `svg` or `png`.
+
+        Note:
+          You need to install graphviz to use this function::
+
+            pip install graphviz
+
+        Args:
+          title:
+            Title to be displayed in the image, e.g. 'A simple FSA example'.
+          filename:
+            Filename to (optionally) save to, e.g. 'foo.png', 'foo.svg' or
+            'foo.pdf' (must have a suffix that graphviz understands).
+          symbol_table:
+            Maps the token ids to symbols.
+        Returns:
+          A Digraph from graphviz.
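+
+        Example (a minimal sketch; the token ids and symbols below are
+        made up)::
+
+            graph = ContextGraph(context_score=1)
+            graph.build([[1, 2], [1, 2, 3]])
+            graph.draw(
+                title="Demo",
+                filename="demo.svg",
+                symbol_table={1: "a", 2: "b", 3: "c"},
+            )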
+ """ + + try: + import graphviz + except Exception: + print("You cannot use `to_dot` unless the graphviz package is installed.") + raise + + graph_attr = { + "rankdir": "LR", + "size": "8.5,11", + "center": "1", + "orientation": "Portrait", + "ranksep": "0.4", + "nodesep": "0.25", + } + if title is not None: + graph_attr["label"] = title + + default_node_attr = { + "shape": "circle", + "style": "bold", + "fontsize": "14", + } + + final_state_attr = { + "shape": "doublecircle", + "style": "bold", + "fontsize": "14", + } + + final_state = -1 + dot = graphviz.Digraph(name="Context Graph", graph_attr=graph_attr) + + seen = set() + queue = deque() + queue.append(self.root) + # root id is always 0 + dot.node("0", label="0", **default_node_attr) + dot.edge("0", "0", color="red") + seen.add(0) + + while len(queue): + current_node = queue.popleft() + for token, node in current_node.next.items(): + if node.id not in seen: + node_score = f"{node.node_score:.2f}".rstrip("0").rstrip(".") + local_node_score = f"{node.local_node_score:.2f}".rstrip( + "0" + ).rstrip(".") + label = f"{node.id}/({node_score},{local_node_score})" + if node.is_end: + dot.node(str(node.id), label=label, **final_state_attr) + else: + dot.node(str(node.id), label=label, **default_node_attr) + seen.add(node.id) + weight = f"{node.token_score:.2f}".rstrip("0").rstrip(".") + label = str(token) if symbol_table is None else symbol_table[token] + dot.edge(str(current_node.id), str(node.id), label=f"{label}/{weight}") + dot.edge( + str(node.id), + str(node.fail.id), + color="red", + ) + if node.output is not None: + dot.edge( + str(node.id), + str(node.output.id), + color="green", + ) + queue.append(node) + + if filename: + _, extension = os.path.splitext(filename) + if extension == "" or extension[0] != ".": + raise ValueError( + "Filename needs to have a suffix like .png, .pdf, .svg: {}".format( + filename + ) + ) + + import tempfile + + with tempfile.TemporaryDirectory() as tmp_dir: + temp_fn = dot.render( + filename="temp", + directory=tmp_dir, + format=extension[1:], + cleanup=True, + ) + + shutil.move(temp_fn, filename) + + return dot + + +if __name__ == "__main__": + contexts_str = [ + "S", + "HE", + "SHE", + "SHELL", + "HIS", + "HERS", + "HELLO", + "THIS", + "THEM", + ] + contexts = [] + for s in contexts_str: + contexts.append([ord(x) for x in s]) + + context_graph = ContextGraph(context_score=1) + context_graph.build(contexts) + + symbol_table = {} + for contexts in contexts_str: + for s in contexts: + symbol_table[ord(s)] = s + + context_graph.draw( + title="Graph for: " + " / ".join(contexts_str), + filename="context_graph.pdf", + symbol_table=symbol_table, + ) + + queries = { + "HEHERSHE": 14, # "HE", "HE", "HERS", "S", "SHE", "HE" + "HERSHE": 12, # "HE", "HERS", "S", "SHE", "HE" + "HISHE": 9, # "HIS", "S", "SHE", "HE" + "SHED": 6, # "S", "SHE", "HE" + "HELL": 2, # "HE" + "HELLO": 7, # "HE", "HELLO" + "DHRHISQ": 4, # "HIS", "S" + "THEN": 2, # "HE" + } + for query, expected_score in queries.items(): + total_scores = 0 + state = context_graph.root + for q in query: + score, state = context_graph.forward_one_step(state, ord(q)) + total_scores += score + score, state = context_graph.finalize(state) + assert state.token == -1, state.token + total_scores += score + assert total_scores == expected_score, ( + total_scores, + expected_score, + query, + )