Implement Aho-Corasick context graph

This commit is contained in:
pkufool 2023-05-08 12:22:20 +08:00
parent 1bc55376b6
commit 2e7e7875f5
4 changed files with 223 additions and 258 deletions

View File

@ -768,9 +768,6 @@ class Hypothesis:
"""Return a string representation of self.ys"""
return "_".join(map(str, self.ys))
def __str__(self) -> str:
return f"ys: {'_'.join([str(i) for i in self.ys])}, log_prob: {float(self.log_prob):.2f}, state: {self.context_state}"
class HypothesisList(object):
def __init__(self, data: Optional[Dict[str, Hypothesis]] = None) -> None:
@ -919,7 +916,6 @@ def modified_beam_search(
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
context_graph: Optional[ContextGraph] = None,
num_context_history: int = 1,
beam: int = 4,
temperature: float = 1.0,
return_timestamps: bool = False,
@ -971,7 +967,7 @@ def modified_beam_search(
Hypothesis(
ys=[blank_id] * context_size,
log_prob=torch.zeros(1, dtype=torch.float32, device=device),
context_state=None if context_graph is None else ContextState(graph=context_graph, max_states=num_context_history),
context_state=None if context_graph is None else context_graph.root,
timestamp=[],
)
)
@ -1056,12 +1052,15 @@ def modified_beam_search(
new_token = topk_token_indexes[k]
new_timestamp = hyp.timestamp[:]
context_score = 0
new_context_state = None if context_graph is None else hyp.context_state.clone()
new_context_state = None if context_graph is None else hyp.context_state
if new_token not in (blank_id, unk_id):
new_ys.append(new_token)
new_timestamp.append(t)
if context_graph is not None:
context_score, new_context_state = hyp.context_state.forward_one_step(new_token)
(
context_score,
new_context_state,
) = context_graph.forward_one_step(hyp.context_state, new_token)
new_log_prob = topk_log_probs[k] + context_score
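Taken together, these changes mean each hypothesis carries nothing more than a pointer to its current trie node: it starts at context_graph.root, forward_one_step returns the per-token bonus plus the next node, and finalize (used in the hunk below) cancels any bonus from an unfinished match at the end of the utterance. A minimal standalone sketch of that flow, assuming a built context_graph and hypothetical names decoded_tokens, blank_id, unk_id:

# Hypothetical scoring loop, not the actual decoder:
state = context_graph.root           # every hypothesis starts at the root
log_prob = 0.0
for token in decoded_tokens:         # tokens chosen by beam search
    if token not in (blank_id, unk_id):
        bonus, state = context_graph.forward_one_step(state, token)
        log_prob += bonus            # bias the hypothesis score
bonus, state = context_graph.finalize(state)
log_prob += bonus                    # subtract any partial-match bonus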
@ -1081,7 +1080,9 @@ def modified_beam_search(
finalized_B = [HypothesisList() for _ in range(len(B))]
for i, hyps in enumerate(B):
for hyp in list(hyps):
context_score, new_context_state = hyp.context_state.finalize()
context_score, new_context_state = context_graph.finalize(
hyp.context_state
)
finalized_B[i].add(
Hypothesis(
ys=hyp.ys,

View File

@ -362,21 +362,20 @@ def get_parser():
"--context-score",
type=float,
default=2,
help="",
)
parser.add_argument(
"--num-context-history",
type=int,
default=1,
help="",
help="""
The bonus score for each token of the context biasing words/phrases.
Used only when --decoding_method is modified_beam_search.
""",
)
parser.add_argument(
"--context-file",
type=str,
default="",
help="",
help="""
The path of the context biasing list; one word/phrase per line.
Used only when --decoding_method is modified_beam_search.
""",
)
add_model_arguments(parser)
@ -522,7 +521,6 @@ def decode_one_batch(
encoder_out_lens=encoder_out_lens,
beam=params.beam_size,
context_graph=context_graph,
num_context_history=params.num_context_history,
return_timestamps=True,
)
else:
@ -579,7 +577,6 @@ def decode_one_batch(
else:
key = f"beam_size_{params.beam_size}"
key += f"-context-score-{params.context_score}"
key += f"-num-context-history-{params.num_context_history}"
return {key: (hyps, timestamps)}
@ -629,7 +626,7 @@ def decode_dataset(
if params.decoding_method == "greedy_search":
log_interval = 50
else:
log_interval = 1
log_interval = 1
results = defaultdict(list)
for batch_idx, batch in enumerate(dl):
@ -785,7 +782,6 @@ def main():
elif "beam_search" in params.decoding_method:
params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
params.suffix += f"-context-score-{params.context_score}"
params.suffix += f"-num-context-history-{params.num_context_history}"
else:
params.suffix += f"-context-{params.context_size}"
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
@ -923,7 +919,7 @@ def main():
for line in open(params.context_file).readlines():
contexts.append(line.strip())
context_graph = ContextGraph(params.context_score)
context_graph.build_context_graph_bpe(contexts, sp)
context_graph.build_context_graph(sp.encode(contexts))
else:
context_graph = None
else:
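For reference, here is the BPE construction above in a condensed, standalone form; the file name and model path are illustrative, not part of this commit:

import sentencepiece as spm
from icefall import ContextGraph

sp = spm.SentencePieceProcessor()
sp.load("data/lang_bpe_500/bpe.model")  # illustrative path

# one context word/phrase per line, e.g. "HELLO WORLD"
contexts = [line.strip() for line in open("contexts.txt") if line.strip()]

context_graph = ContextGraph(context_score=2.0)  # per-token bonus
context_graph.build_context_graph(sp.encode(contexts))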

View File

@ -136,6 +136,7 @@ from beam_search import (
from train import add_model_arguments, get_params, get_transducer_model
from icefall import ContextGraph, LmScorer, NgramLm
from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
from icefall.checkpoint import (
average_checkpoints,
average_checkpoints_with_averaged_model,
@ -313,21 +314,20 @@ def get_parser():
"--context-score",
type=float,
default=2,
help="",
)
parser.add_argument(
"--num-context-history",
type=int,
default=1,
help="",
help="""
The bonus score for each token of the context biasing words/phrases.
Used only when --decoding_method is modified_beam_search.
""",
)
parser.add_argument(
"--context-file",
type=str,
default="",
help="",
help="""
The path of the context biasing list; one word/phrase per line.
Used only when --decoding_method is modified_beam_search.
""",
)
parser.add_argument(
@ -472,7 +472,6 @@ def decode_one_batch(
beam=params.beam_size,
encoder_out_lens=encoder_out_lens,
context_graph=context_graph,
num_context_history=params.num_context_history,
)
for i in range(encoder_out.size(0)):
hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
@ -535,7 +534,7 @@ def decode_one_batch(
}
else:
return {
f"beam_size_{params.beam_size}_context_score_{params.context_score}_num_context_history_{params.num_context_history}": hyps
f"beam_size_{params.beam_size}_context_score_{params.context_score}": hyps
}
@ -685,7 +684,6 @@ def main():
elif "beam_search" in params.decoding_method:
params.suffix += f"-beam-{params.beam_size}"
params.suffix += f"-context-score-{params.context_score}"
params.suffix += f"-num-context-history-{params.num_context_history}"
else:
params.suffix += f"-context-{params.context_size}"
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
@ -715,11 +713,15 @@ def main():
logging.info(f"Device: {device}")
lexicon = Lexicon(params.lang_dir)
params.blank_id = lexicon.token_table["<blk>"]
params.vocab_size = max(lexicon.tokens) + 1
graph_compiler = CharCtcTrainingGraphCompiler(
lexicon=lexicon,
device=device,
)
if params.simulate_streaming:
assert (
params.causal_convolution
@ -851,9 +853,9 @@ def main():
if os.path.exists(params.context_file):
contexts = []
for line in open(params.context_file).readlines():
contexts.append(line.strip())
contexts.append(graph_compiler.texts_to_ids(line.strip()))
context_graph = ContextGraph(params.context_score)
context_graph.build_context_graph_char(contexts, lexicon.token_table)
context_graph.build_context_graph(contexts)
else:
context_graph = None
else:
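The char-based recipe maps each phrase to token ids through the lexicon's graph compiler rather than a BPE model. A condensed sketch of the construction above, assuming the graph_compiler and params set up in main() and an illustrative file name:

contexts = []
for line in open("contexts.txt"):  # one word/phrase per line
    if line.strip():
        contexts.append(graph_compiler.texts_to_ids(line.strip()))

context_graph = ContextGraph(params.context_score)
context_graph.build_context_graph(contexts)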

View File

@ -14,243 +14,209 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from heapq import heappush, heappop
import re
from dataclasses import dataclass
from typing import List, Tuple
import argparse
import k2
import kaldifst
import sentencepiece as spm
from icefall.utils import is_module_available
class ContextGraph:
def __init__(self, context_score: float = 1):
self.context_score = context_score
def build_context_graph_char(
self, contexts: List[str], token_table: k2.SymbolTable
):
"""Convert a list of texts to a list-of-list of token IDs.
Args:
contexts:
It is a list of strings.
An example containing two strings is given below:
['你好中国', '北京欢迎您']
token_table:
The SymbolTable containing tokens and corresponding ids.
Returns:
Return a list-of-list of token IDs.
"""
ids: List[List[int]] = []
whitespace = re.compile(r"([ \t])")
for text in contexts:
text = re.sub(whitespace, "", text)
sub_ids: List[int] = []
skip = False
for txt in text:
if txt not in token_table:
skip = True
break
sub_ids.append(token_table[txt])
if skip:
logging.warning(f"Skipping context {text}, as it has OOV char.")
continue
ids.append(sub_ids)
self.build_context_graph(ids)
def build_context_graph_bpe(
self, contexts: List[str], sp: spm.SentencePieceProcessor
):
contexts_bpe = sp.encode(contexts)
self.build_context_graph(contexts_bpe)
def build_context_graph(self, token_ids: List[List[int]]):
graph = kaldifst.StdVectorFst()
start_state = (
graph.add_state()
) # 1st state will be state 0 (returned by add_state)
assert start_state == 0, start_state
graph.start = 0 # set the start state to 0
graph.set_final(start_state, weight=kaldifst.TropicalWeight.one)
for tokens in token_ids:
prev_state = start_state
next_state = start_state
backoff_score = 0
for i in range(len(tokens)):
score = self.context_score
next_state = graph.add_state() if i < len(tokens) - 1 else start_state
graph.add_arc(
state=prev_state,
arc=kaldifst.StdArc(
ilabel=tokens[i],
olabel=tokens[i],
weight=score,
nextstate=next_state,
),
)
if i > 0:
graph.add_arc(
state=prev_state,
arc=kaldifst.StdArc(
ilabel=0,
olabel=0,
weight=-backoff_score,
nextstate=start_state,
),
)
prev_state = next_state
backoff_score += score
self.graph = kaldifst.determinize(graph)
kaldifst.arcsort(self.graph)
def is_final_state(self, state_id: int) -> bool:
return self.graph.final(state_id) == kaldifst.TropicalWeight.one
def get_next_state(self, state_id: int, label: int) -> Tuple[int, float, bool]:
arc_iter = kaldifst.ArcIterator(self.graph, state_id)
num_arcs = self.graph.num_arcs(state_id)
# The LM is arc sorted by ilabel, so we use binary search below.
left = 0
right = num_arcs - 1
while left <= right:
mid = (left + right) // 2
arc_iter.seek(mid)
arc = arc_iter.value
if arc.ilabel < label:
left = mid + 1
elif arc.ilabel > label:
right = mid - 1
else:
return (arc.nextstate, arc.weight.value, True)
# Backoff to state 0 with the score on epsilon arc (ilabel == 0)
arc_iter.seek(0)
arc = arc_iter.value
if arc.ilabel == 0:
return (0, 0, False)
else:
return (0, arc.weight.value, False)
from typing import Dict, List, Tuple
class ContextState:
def __init__(self, graph: ContextGraph, max_states: int):
self.graph = graph
self.max_states = max_states
# [(total score, (score, state_id))]
self.states: List[Tuple[float, Tuple[float, int]]] = []
"""The state in ContextGraph"""
def __str__(self):
return ";".join([str(state) for state in self.states])
def __init__(self, token: int, score: float, total_score: float, is_end: bool):
"""Create a ContextState.
def clone(self):
new_context_state = ContextState(graph=self.graph, max_states=self.max_states)
new_context_state.states = self.states[:]
return new_context_state
Args:
token:
The token id.
score:
The bonus for each token during decoding; it will hopefully
boost the token enough to survive beam search.
total_score:
The accumulated bonus from the root of the graph to the current node;
it is used to calculate the score of the fail arc.
is_end:
True if the current token is the end of a context.
"""
self.token = token
self.score = score
self.total_score = total_score
self.is_end = is_end
self.next = {}
self.fail = None
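For concreteness, after building the HE/SHE/HIS/HERS example exercised at the bottom of this file with context_score=2, the node reached by the path "SH" would hold the following (a hand-worked illustration, not program output):

# node "SH" (root -> "S" -> "H"):
#   token       = ord("H")
#   score       = 2              # per-token bonus
#   total_score = 4              # accumulated from root: "S" (+2) + "H" (+2)
#   is_end      = False          # "SH" is not a full context
#   next        = {ord("E"): <node "SHE">}
#   fail        = <node "H">     # filled in later by ContextGraph._fill_fail()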
def finalize(self) -> float:
new_context_state = ContextState(graph=self.graph, max_states=self.max_states)
if len(self.states) == 0:
return 0, new_context_state
item = heappop(self.states)
return item[0], new_context_state
def forward_one_step(self, label: int) -> float:
states = self.states[:]
new_states = []
# expand current label from the start state
status = self.graph.get_next_state(0, label)
if status[2]:
heappush(new_states, (-status[1], (status[1], status[0])))
class ContextGraph:
"""The ContextGraph is modified from Aho-Corasick which is mainly
a Trie with a fail arc for each node.
See https://en.wikipedia.org/wiki/Aho%E2%80%93Corasick_algorithm for more details
of Aho-Corasick algorithm.
A ContextGraph contains some words / phrases that we expect to boost their
scores during decoding. If the substring of a decoded sequence matches the word / phrase
in the ContextGraph, we will give the decoded sequence a bonus to make it survive
beam search.
"""
def __init__(self, context_score: float):
"""Initialize a ContextGraph with the given ``context_score``.
A root node will be created (**NOTE:** the token of root is hardcoded to -1).
Args:
context_score:
The bonus score for each token (note: NOT for each word/phrase; a longer
word/phrase accumulates a larger bonus, but it has to be matched in full).
"""
self.context_score = context_score
self.root = ContextState(token=-1, score=0, total_score=0, is_end=False)
self.root.fail = self.root
def _fill_fail(self):
"""This function fills the fail arc for each trie node, it can be computed
in linear time by performing a breadth-first search starting from the root.
See https://en.wikipedia.org/wiki/Aho%E2%80%93Corasick_algorithm for the
details of the algorithm.
"""
queue = []
for token, node in self.root.next.items():
node.fail = self.root
queue.append(node)
while queue:
current_node = queue.pop(0)
current_fail = current_node.fail
for token, node in current_node.next.items():
fail = current_fail
if token in current_fail.next:
fail = current_fail.next[token]
node.fail = fail
queue.append(node)
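To make the BFS concrete: for the contexts ["HE", "SHE", "HIS", "HERS"] (the set exercised in the __main__ test below), the computed fail arcs come out as follows, writing each node as the trie path that reaches it:

# fail("H")    = root      fail("S")   = root
# fail("HE")   = root      fail("SH")  = "H"
# fail("HI")   = root      fail("SHE") = "HE"
# fail("HIS")  = "S"       fail("HER") = root
# fail("HERS") = "S"
# i.e. each node fails to the longest proper suffix of its path that is
# also a path in the trie.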
def build_context_graph(self, token_ids: List[List[int]]):
"""Build the ContextGraph from a list of token list.
It first build a trie from the given token lists, then fill the fail arc
for each trie node.
See https://en.wikipedia.org/wiki/Trie for how to build a trie.
Args:
token_ids:
The token lists used to build the ContextGraph; it is a list of token lists,
and each token list contains the token ids of one word/phrase. A token id
could be the id of a char (when modeling with single Chinese characters) or
of a BPE piece (when modeling with BPE).
"""
for tokens in token_ids:
node = self.root
for i, token in enumerate(tokens):
if token not in node.next:
node.next[token] = ContextState(
token=token,
score=self.context_score,
# The total score is the accumulated score from the root to the
# current node; it is used to compute the fail-arc score later.
total_score=node.total_score + self.context_score,
is_end=i == len(tokens) - 1,
)
node = node.next[token]
self._fill_fail()
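A quick sanity check of the resulting structure, runnable against this class (token ids chosen as raw character codes purely for illustration):

graph = ContextGraph(context_score=2)
graph.build_context_graph([[ord(c) for c in w] for w in ["HE", "SHE", "HIS", "HERS"]])

he = graph.root.next[ord("H")].next[ord("E")]
assert he.is_end and he.total_score == 4        # "HE" is a full context
her = he.next[ord("R")]
assert not her.is_end and her.total_score == 6  # "HER" is only a prefix of "HERS"

she = graph.root.next[ord("S")].next[ord("H")].next[ord("E")]
assert she.fail is he                           # fail arc filled by _fill_fail()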
def forward_one_step(
self, state: ContextState, token: int
) -> Tuple[float, ContextState]:
"""Search the graph with given state and token.
Args:
state:
The given state (trie node) to start.
token:
The given token.
Returns:
Return a tuple of score and next state.
"""
# token matched
if token in state.next:
node = state.next[token]
score = node.score
# If the matched node is the end of a word/phrase, we will start
# from the root for the next token.
if node.is_end:
node = self.root
return (score, node)
else:
assert status[0] == 0 and status[2] == False, status
# token not matched
# We trace along the fail arcs until the token matches or we reach
# the root of the graph.
node = state.fail
while token not in node.next:
node = node.fail
if node.token == -1: # root
break
# the score we have added to the path till now
prev_max_total_score = 0
# expand previous states with given label
while states:
state = heappop(states)
if -state[0] > prev_max_total_score:
prev_max_total_score = -state[0]
if token in node.next:
node = node.next[token]
# The score of the fail arc
score = node.total_score - state.total_score
if node.is_end:
node = self.root
return (score, node)
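A hand-worked trace under the same HE/SHE/HIS/HERS graph with context_score=2: "S" and "H" each earn +2 while walking the SHE branch, then "D" matches nothing, and the fail-arc score (root.total_score - state.total_score = -4) takes the unearned bonus back:

graph = ContextGraph(context_score=2)
graph.build_context_graph([[ord(c) for c in w] for w in ["HE", "SHE", "HIS", "HERS"]])

state, total = graph.root, 0.0
for ch in "SHD":
    score, state = graph.forward_one_step(state, ord(ch))
    total += score                  # +2, +2, then -4 on the mismatch
assert total == 0.0 and state is graph.root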
status = self.graph.get_next_state(state[1][1], label)
def finalize(self, state: ContextState) -> Tuple[float, ContextState]:
"""When reaching the end of the decoded sequence, we need to finalize
the matching, the purpose is to subtract the added bonus score for the
state that is not the end of a word/phrase.
if status[2]:
heappush(new_states, (state[0] - status[1], (status[1], status[0])))
else:
pass
# assert status == (0, state[0], False), status
num_states_drop = (
0
if len(new_states) <= self.max_states
else len(new_states) - self.max_states
)
states = []
if len(new_states) == 0:
new_context_state = ContextState(graph=self.graph, max_states=self.max_states)
return -prev_max_total_score, new_context_state
item = heappop(new_states)
# If one item matches a context, clear all states (i.e., start a new
# context from the next label) and return the score of the current label.
if self.graph.is_final_state(item[1][1]):
new_context_state = ContextState(graph=self.graph, max_states=self.max_states)
return -item[0] - prev_max_total_score, new_context_state
max_total_score = -item[0]
heappush(states, item)
while num_states_drop != 0:
item = heappop(new_states)
if self.graph.is_final_state(item[1][1]):
new_context_state = ContextState(graph=self.graph, max_states=self.max_states)
return -item[0] - prev_max_total_score, new_context_state
num_states_drop -= 1
while new_states:
item = heappop(new_states)
if self.graph.is_final_state(item[1][1]):
new_context_state = ContextState(graph=self.graph, max_states=self.max_states)
return -item[0] - prev_max_total_score, new_context_state
heappush(states, item)
# No context matched; the matching may continue with the previous prefix
# or switch to another prefix.
new_context_state = ContextState(graph=self.graph, max_states=self.max_states)
new_context_state.states = states
return max_total_score - prev_max_total_score, new_context_state
Args:
state:
The given state (trie node).
Returns:
Return a tuple of score and next state. If the state is the end of a
word/phrase, the score is zero; otherwise it is the score of an implicit
fail arc to the root. The next state is always the root.
"""
# The score of the fail arc
score = self.root.total_score - state.total_score
if state.is_end:
score = 0
return (score, self.root)
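Worked example of the cancellation: after consuming "HI" (a prefix of "HIS" but not a full context) the hypothesis has collected +4, which finalize takes back:

graph = ContextGraph(context_score=2)
graph.build_context_graph([[ord(c) for c in w] for w in ["HE", "SHE", "HIS", "HERS"]])

state, total = graph.root, 0.0
for ch in "HI":
    score, state = graph.forward_one_step(state, ord(ch))
    total += score                 # +2 for "H", +2 for "I"
score, state = graph.finalize(state)
assert total + score == 0.0        # the partial-match bonus is subtracted
assert state is graph.root         # finalize always returns to the root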
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--bpe_model",
type=str,
help="Path to bpe model",
)
args = parser.parse_args()
contexts_str = ["HE", "SHE", "HIS", "HERS"]
contexts = []
for s in contexts_str:
contexts.append([ord(x) for x in s])
sp = spm.SentencePieceProcessor()
sp.load(args.bpe_model)
context_graph = ContextGraph(context_score=2)
context_graph.build_context_graph(contexts)
contexts = ["LOVE CHINA", "HELLO WORLD", "LOVE WORLD"]
context_graph = ContextGraph()
context_graph.build_context_graph_bpe(contexts, sp)
score, state = context_graph.forward_one_step(context_graph.root, ord("H"))
assert score == 2, score
assert state.token == ord("H"), state.token
if not is_module_available("graphviz"):
raise ValueError("Please 'pip install graphviz' first.")
import graphviz
score, state = context_graph.forward_one_step(state, ord("I"))
assert score == 2, score
assert state.token == ord("I"), state.token
fst_dot = kaldifst.draw(context_graph.graph, acceptor=False, portrait=True)
fst_source = graphviz.Source(fst_dot)
fst_source.render(outfile="context_graph.svg")
score, state = context_graph.forward_one_step(state, ord("S"))
assert score == 2, score
assert state.token == -1, state.token
score, state = context_graph.finalize(state)
assert score == 0, score
assert state.token == -1, state.token
score, state = context_graph.forward_one_step(context_graph.root, ord("S"))
assert score == 2, score
assert state.token == ord("S"), state.token
score, state = context_graph.forward_one_step(state, ord("H"))
assert score == 2, score
assert state.token == ord("H"), state.token
score, state = context_graph.forward_one_step(state, ord("D"))
assert score == -4, score
assert state.token == -1, state.token
score, state = context_graph.forward_one_step(context_graph.root, ord("D"))
assert score == 0, score
assert state.token == -1, state.token