Implement Aho-Corasick context graph

This commit is contained in:
pkufool 2023-05-08 12:22:20 +08:00
parent 1bc55376b6
commit 2e7e7875f5
4 changed files with 223 additions and 258 deletions

View File

@ -768,9 +768,6 @@ class Hypothesis:
"""Return a string representation of self.ys""" """Return a string representation of self.ys"""
return "_".join(map(str, self.ys)) return "_".join(map(str, self.ys))
def __str__(self) -> str:
return f"ys: {'_'.join([str(i) for i in self.ys])}, log_prob: {float(self.log_prob):.2f}, state: {self.context_state}"
class HypothesisList(object): class HypothesisList(object):
def __init__(self, data: Optional[Dict[str, Hypothesis]] = None) -> None: def __init__(self, data: Optional[Dict[str, Hypothesis]] = None) -> None:
@ -919,7 +916,6 @@ def modified_beam_search(
encoder_out: torch.Tensor, encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor, encoder_out_lens: torch.Tensor,
context_graph: Optional[ContextGraph] = None, context_graph: Optional[ContextGraph] = None,
num_context_history: int = 1,
beam: int = 4, beam: int = 4,
temperature: float = 1.0, temperature: float = 1.0,
return_timestamps: bool = False, return_timestamps: bool = False,
@ -971,7 +967,7 @@ def modified_beam_search(
Hypothesis( Hypothesis(
ys=[blank_id] * context_size, ys=[blank_id] * context_size,
log_prob=torch.zeros(1, dtype=torch.float32, device=device), log_prob=torch.zeros(1, dtype=torch.float32, device=device),
context_state=None if context_graph is None else ContextState(graph=context_graph, max_states=num_context_history), context_state=None if context_graph is None else context_graph.root,
timestamp=[], timestamp=[],
) )
) )
@ -1056,12 +1052,15 @@ def modified_beam_search(
new_token = topk_token_indexes[k] new_token = topk_token_indexes[k]
new_timestamp = hyp.timestamp[:] new_timestamp = hyp.timestamp[:]
context_score = 0 context_score = 0
new_context_state = None if context_graph is None else hyp.context_state.clone() new_context_state = None if context_graph is None else hyp.context_state
if new_token not in (blank_id, unk_id): if new_token not in (blank_id, unk_id):
new_ys.append(new_token) new_ys.append(new_token)
new_timestamp.append(t) new_timestamp.append(t)
if context_graph is not None: if context_graph is not None:
context_score, new_context_state = hyp.context_state.forward_one_step(new_token) (
context_score,
new_context_state,
) = context_graph.forward_one_step(hyp.context_state, new_token)
new_log_prob = topk_log_probs[k] + context_score new_log_prob = topk_log_probs[k] + context_score
@ -1081,7 +1080,9 @@ def modified_beam_search(
finalized_B = [HypothesisList() for _ in range(len(B))] finalized_B = [HypothesisList() for _ in range(len(B))]
for i, hyps in enumerate(B): for i, hyps in enumerate(B):
for hyp in list(hyps): for hyp in list(hyps):
context_score, new_context_state = hyp.context_state.finalize() context_score, new_context_state = context_graph.finalize(
hyp.context_state
)
finalized_B[i].add( finalized_B[i].add(
Hypothesis( Hypothesis(
ys=hyp.ys, ys=hyp.ys,

View File

@ -362,21 +362,20 @@ def get_parser():
"--context-score", "--context-score",
type=float, type=float,
default=2, default=2,
help="", help="""
) The bonus score of each token for the context biasing words/phrases.
Used only when --decoding_method is modified_beam_search.
parser.add_argument( """,
"--num-context-history",
type=int,
default=1,
help="",
) )
parser.add_argument( parser.add_argument(
"--context-file", "--context-file",
type=str, type=str,
default="", default="",
help="", help="""
The path of the context biasing lists, one word/phrase each line
Used only when --decoding_method is modified_beam_search.
""",
) )
add_model_arguments(parser) add_model_arguments(parser)
@ -522,7 +521,6 @@ def decode_one_batch(
encoder_out_lens=encoder_out_lens, encoder_out_lens=encoder_out_lens,
beam=params.beam_size, beam=params.beam_size,
context_graph=context_graph, context_graph=context_graph,
num_context_history=params.num_context_history,
return_timestamps=True, return_timestamps=True,
) )
else: else:
@ -579,7 +577,6 @@ def decode_one_batch(
else: else:
key = f"beam_size_{params.beam_size}" key = f"beam_size_{params.beam_size}"
key += f"-context-score-{params.context_score}" key += f"-context-score-{params.context_score}"
key += f"-num-context-history-{params.num_context_history}"
return {key: (hyps, timestamps)} return {key: (hyps, timestamps)}
@ -785,7 +782,6 @@ def main():
elif "beam_search" in params.decoding_method: elif "beam_search" in params.decoding_method:
params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
params.suffix += f"-context-score-{params.context_score}" params.suffix += f"-context-score-{params.context_score}"
params.suffix += f"-num-context-history-{params.num_context_history}"
else: else:
params.suffix += f"-context-{params.context_size}" params.suffix += f"-context-{params.context_size}"
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
@ -923,7 +919,7 @@ def main():
for line in open(params.context_file).readlines(): for line in open(params.context_file).readlines():
contexts.append(line.strip()) contexts.append(line.strip())
context_graph = ContextGraph(params.context_score) context_graph = ContextGraph(params.context_score)
context_graph.build_context_graph_bpe(contexts, sp) context_graph.build_context_graph(sp.encode(contexts))
else: else:
context_graph = None context_graph = None
else: else:

View File

@ -136,6 +136,7 @@ from beam_search import (
from train import add_model_arguments, get_params, get_transducer_model from train import add_model_arguments, get_params, get_transducer_model
from icefall import ContextGraph, LmScorer, NgramLm from icefall import ContextGraph, LmScorer, NgramLm
from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
from icefall.checkpoint import ( from icefall.checkpoint import (
average_checkpoints, average_checkpoints,
average_checkpoints_with_averaged_model, average_checkpoints_with_averaged_model,
@ -313,21 +314,20 @@ def get_parser():
"--context-score", "--context-score",
type=float, type=float,
default=2, default=2,
help="", help="""
) The bonus score of each token for the context biasing words/phrases.
Used only when --decoding_method is modified_beam_search.
parser.add_argument( """,
"--num-context-history",
type=int,
default=1,
help="",
) )
parser.add_argument( parser.add_argument(
"--context-file", "--context-file",
type=str, type=str,
default="", default="",
help="", help="""
The path of the context biasing lists, one word/phrase each line
Used only when --decoding_method is modified_beam_search.
""",
) )
parser.add_argument( parser.add_argument(
@ -472,7 +472,6 @@ def decode_one_batch(
beam=params.beam_size, beam=params.beam_size,
encoder_out_lens=encoder_out_lens, encoder_out_lens=encoder_out_lens,
context_graph=context_graph, context_graph=context_graph,
num_context_history=params.num_context_history,
) )
for i in range(encoder_out.size(0)): for i in range(encoder_out.size(0)):
hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
@ -535,7 +534,7 @@ def decode_one_batch(
} }
else: else:
return { return {
f"beam_size_{params.beam_size}_context_score_{params.context_score}_num_context_history_{params.num_context_history}": hyps f"beam_size_{params.beam_size}_context_score_{params.context_score}": hyps
} }
@ -685,7 +684,6 @@ def main():
elif "beam_search" in params.decoding_method: elif "beam_search" in params.decoding_method:
params.suffix += f"-beam-{params.beam_size}" params.suffix += f"-beam-{params.beam_size}"
params.suffix += f"-context-score-{params.context_score}" params.suffix += f"-context-score-{params.context_score}"
params.suffix += f"-num-context-history-{params.num_context_history}"
else: else:
params.suffix += f"-context-{params.context_size}" params.suffix += f"-context-{params.context_size}"
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
@ -715,11 +713,15 @@ def main():
logging.info(f"Device: {device}") logging.info(f"Device: {device}")
# import pdb; pdb.set_trace()
lexicon = Lexicon(params.lang_dir) lexicon = Lexicon(params.lang_dir)
params.blank_id = lexicon.token_table["<blk>"] params.blank_id = lexicon.token_table["<blk>"]
params.vocab_size = max(lexicon.tokens) + 1 params.vocab_size = max(lexicon.tokens) + 1
graph_compiler = CharCtcTrainingGraphCompiler(
lexicon=lexicon,
device=device,
)
if params.simulate_streaming: if params.simulate_streaming:
assert ( assert (
params.causal_convolution params.causal_convolution
@ -851,9 +853,9 @@ def main():
if os.path.exists(params.context_file): if os.path.exists(params.context_file):
contexts = [] contexts = []
for line in open(params.context_file).readlines(): for line in open(params.context_file).readlines():
contexts.append(line.strip()) contexts.append(graph_compiler.texts_to_ids(line.strip()))
context_graph = ContextGraph(params.context_score) context_graph = ContextGraph(params.context_score)
context_graph.build_context_graph_char(contexts, lexicon.token_table) context_graph.build_context_graph(contexts)
else: else:
context_graph = None context_graph = None
else: else:

View File

@ -14,243 +14,209 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from heapq import heappush, heappop from typing import Dict, List, Tuple
import re
from dataclasses import dataclass
from typing import List, Tuple
import argparse
import k2
import kaldifst
import sentencepiece as spm
from icefall.utils import is_module_available
class ContextGraph:
def __init__(self, context_score: float = 1):
self.context_score = context_score
def build_context_graph_char(
self, contexts: List[str], token_table: k2.SymbolTable
):
"""Convert a list of texts to a list-of-list of token IDs.
Args:
contexts:
It is a list of strings.
An example containing two strings is given below:
['你好中国', '北京欢迎您']
token_table:
The SymbolTable containing tokens and corresponding ids.
Returns:
Return a list-of-list of token IDs.
"""
ids: List[List[int]] = []
whitespace = re.compile(r"([ \t])")
for text in contexts:
text = re.sub(whitespace, "", text)
sub_ids: List[int] = []
skip = False
for txt in text:
if txt not in token_table:
skip = True
break
sub_ids.append(token_table[txt])
if skip:
logging.warning(f"Skipping context {text}, as it has OOV char.")
continue
ids.append(sub_ids)
self.build_context_graph(ids)
def build_context_graph_bpe(
self, contexts: List[str], sp: spm.SentencePieceProcessor
):
contexts_bpe = sp.encode(contexts)
self.build_context_graph(contexts_bpe)
def build_context_graph(self, token_ids: List[List[int]]):
graph = kaldifst.StdVectorFst()
start_state = (
graph.add_state()
) # 1st state will be state 0 (returned by add_state)
assert start_state == 0, start_state
graph.start = 0 # set the start state to 0
graph.set_final(start_state, weight=kaldifst.TropicalWeight.one)
for tokens in token_ids:
prev_state = start_state
next_state = start_state
backoff_score = 0
for i in range(len(tokens)):
score = self.context_score
next_state = graph.add_state() if i < len(tokens) - 1 else start_state
graph.add_arc(
state=prev_state,
arc=kaldifst.StdArc(
ilabel=tokens[i],
olabel=tokens[i],
weight=score,
nextstate=next_state,
),
)
if i > 0:
graph.add_arc(
state=prev_state,
arc=kaldifst.StdArc(
ilabel=0,
olabel=0,
weight=-backoff_score,
nextstate=start_state,
),
)
prev_state = next_state
backoff_score += score
self.graph = kaldifst.determinize(graph)
kaldifst.arcsort(self.graph)
def is_final_state(self, state_id: int) -> bool:
return self.graph.final(state_id) == kaldifst.TropicalWeight.one
def get_next_state(self, state_id: int, label: int) -> Tuple[int, float, bool]:
arc_iter = kaldifst.ArcIterator(self.graph, state_id)
num_arcs = self.graph.num_arcs(state_id)
# The LM is arc sorted by ilabel, so we use binary search below.
left = 0
right = num_arcs - 1
while left <= right:
mid = (left + right) // 2
arc_iter.seek(mid)
arc = arc_iter.value
if arc.ilabel < label:
left = mid + 1
elif arc.ilabel > label:
right = mid - 1
else:
return (arc.nextstate, arc.weight.value, True)
# Backoff to state 0 with the score on epsilon arc (ilabel == 0)
arc_iter.seek(0)
arc = arc_iter.value
if arc.ilabel == 0:
return (0, 0, False)
else:
return (0, arc.weight.value, False)
class ContextState: class ContextState:
def __init__(self, graph: ContextGraph, max_states: int): """The state in ContextGraph"""
self.graph = graph
self.max_states = max_states
# [(total score, (score, state_id))]
self.states: List[Tuple[float, Tuple[float, int]]] = []
def __str__(self): def __init__(self, token: int, score: float, total_score: float, is_end: bool):
return ";".join([str(state) for state in self.states]) """Create a ContextState.
def clone(self): Args:
new_context_state = ContextState(graph=self.graph, max_states=self.max_states) token:
new_context_state.states = self.states[:] The token id.
return new_context_state score:
The bonus for each token during decoding, which will hopefully
boost the token up to survive beam search.
total_score:
The accumulated bonus from root of graph to current node, it will be
used to calculate the score for fail arc.
is_end:
True if current token is the end of a context.
"""
self.token = token
self.score = score
self.total_score = total_score
self.is_end = is_end
self.next = {}
self.fail = None
def finalize(self) -> float:
new_context_state = ContextState(graph=self.graph, max_states=self.max_states)
if len(self.states) == 0:
return 0, new_context_state
item = heappop(self.states)
return item[0], new_context_state
def forward_one_step(self, label: int) -> float: class ContextGraph:
states = self.states[:] """The ContextGraph is modified from Aho-Corasick which is mainly
new_states = [] a Trie with a fail arc for each node.
# expand current label from state state See https://en.wikipedia.org/wiki/Aho%E2%80%93Corasick_algorithm for more details
status = self.graph.get_next_state(0, label) of Aho-Corasick algorithm.
if status[2]:
heappush(new_states, (-status[1], (status[1], status[0])))
else:
assert status[0] == 0 and status[2] == False, status
# the score we have added to the path till now A ContextGraph contains some words / phrases that we expect to boost their
prev_max_total_score = 0 scores during decoding. If the substring of a decoded sequence matches the word / phrase
# expand previous states with given label in the ContextGraph, we will give the decoded sequence a bonus to make it survive
while states: beam search.
state = heappop(states) """
if -state[0] > prev_max_total_score:
prev_max_total_score = -state[0]
status = self.graph.get_next_state(state[1][1], label) def __init__(self, context_score: float):
"""Initialize a ContextGraph with the given ``context_score``.
if status[2]: A root node will be created (**NOTE:** the token of root is hardcoded to -1).
heappush(new_states, (state[0] - status[1], (status[1], status[0])))
else: Args:
pass context_score:
# assert status == (0, state[0], False), status The bonus score for each token(note: NOT for each word/phrase, it means longer
num_states_drop = ( word/phrase will have larger bonus score, they have to be matched though).
0 """
if len(new_states) <= self.max_states self.context_score = context_score
else len(new_states) - self.max_states self.root = ContextState(token=-1, score=0, total_score=0, is_end=False)
self.root.fail = self.root
def _fill_fail(self):
"""This function fills the fail arc for each trie node, it can be computed
in linear time by performing a breadth-first search starting from the root.
See https://en.wikipedia.org/wiki/Aho%E2%80%93Corasick_algorithm for the
details of the algorithm.
"""
queue = []
for token, node in self.root.next.items():
node.fail = self.root
queue.append(node)
while queue:
current_node = queue.pop(0)
current_fail = current_node.fail
for token, node in current_node.next.items():
fail = current_fail
if token in current_fail.next:
fail = current_fail.next[token]
node.fail = fail
queue.append(node)
def build_context_graph(self, token_ids: List[List[int]]):
"""Build the ContextGraph from a list of token list.
It first build a trie from the given token lists, then fill the fail arc
for each trie node.
See https://en.wikipedia.org/wiki/Trie for how to build a trie.
Args:
token_ids:
The given token lists to build the ContextGraph, it is a list of token list,
each token list contains the token ids for a word/phrase. The token id
could be an id of a char (modeling with single Chinese char) or an id
of a BPE (modeling with BPEs).
"""
for tokens in token_ids:
node = self.root
for i, token in enumerate(tokens):
if token not in node.next:
node.next[token] = ContextState(
token=token,
score=self.context_score,
# The total score is the accumulated score from root to current node,
# it will be used to calculate the score of fail arc later.
total_score=node.total_score + self.context_score,
is_end=i == len(tokens) - 1,
) )
node = node.next[token]
self._fill_fail()
states = [] def forward_one_step(
if len(new_states) == 0: self, state: ContextState, token: int
new_context_state = ContextState(graph=self.graph, max_states=self.max_states) ) -> Tuple[float, ContextState]:
return -prev_max_total_score, new_context_state """Search the graph with given state and token.
item = heappop(new_states) Args:
state:
The given state (trie node) to start.
token:
The given token.
# if one item match a context, clear all states (means start a new context Returns:
# from next label), and return the score of current label Return a tuple of score and next state.
if self.graph.is_final_state(item[1][1]): """
new_context_state = ContextState(graph=self.graph, max_states=self.max_states) # token matched
return -item[0] - prev_max_total_score, new_context_state if token in state.next:
node = state.next[token]
score = node.score
# if the matched node is the end of a word/phrase, we will start
# from the root for next token.
if node.is_end:
node = self.root
return (score, node)
else:
# token not matched
# We will trace along the fail arc until it matches the token or reaching
# root of the graph.
node = state.fail
while token not in node.next:
node = node.fail
if node.token == -1: # root
break
max_total_score = -item[0] if token in node.next:
heappush(states, item) node = node.next[token]
# The score of the fail arc
score = node.total_score - state.total_score
if node.is_end:
node = self.root
return (score, node)
while num_states_drop != 0: def finalize(self, state: ContextState) -> Tuple[float, ContextState]:
item = heappop(new_states) """When reaching the end of the decoded sequence, we need to finalize
if self.graph.is_final_state(item[1][1]): the matching, the purpose is to subtract the added bonus score for the
new_context_state = ContextState(graph=self.graph, max_states=self.max_states) state that is not the end of a word/phrase.
return -item[0] - prev_max_total_score, new_context_state
num_states_drop -= 1
while new_states: Args:
item = heappop(new_states) state:
if self.graph.is_final_state(item[1][1]): The given state(trie node).
new_context_state = ContextState(graph=self.graph, max_states=self.max_states)
return -item[0] - prev_max_total_score, new_context_state
heappush(states, item)
# no context matched, the matching may continue with previous prefix,
# or change to another prefix.
new_context_state = ContextState(graph=self.graph, max_states=self.max_states)
new_context_state.states = states
return max_total_score - prev_max_total_score, new_context_state
Returns:
Return a tuple of score and next state. If state is the end of a word/phrase
the score is zero, otherwise the score is the score of a implicit fail arc
to root. The next state is always root.
"""
# The score of the fail arc
score = self.root.total_score - state.total_score
if state.is_end:
score = 0
return (score, self.root)
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser() contexts_str = ["HE", "SHE", "HIS", "HERS"]
parser.add_argument( contexts = []
"--bpe_model", for s in contexts_str:
type=str, contexts.append([ord(x) for x in s])
help="Path to bpe model",
)
args = parser.parse_args()
sp = spm.SentencePieceProcessor() context_graph = ContextGraph(context_score=2)
sp.load(args.bpe_model) context_graph.build_context_graph(contexts)
contexts = ["LOVE CHINA", "HELLO WORLD", "LOVE WORLD"] score, state = context_graph.forward_one_step(context_graph.root, ord("H"))
context_graph = ContextGraph() assert score == 2, score
context_graph.build_context_graph_bpe(contexts, sp) assert state.token == ord("H"), state.token
if not is_module_available("graphviz"): score, state = context_graph.forward_one_step(state, ord("I"))
raise ValueError("Please 'pip install graphviz' first.") assert score == 2, score
import graphviz assert state.token == ord("I"), state.token
fst_dot = kaldifst.draw(context_graph.graph, acceptor=False, portrait=True) score, state = context_graph.forward_one_step(state, ord("S"))
fst_source = graphviz.Source(fst_dot) assert score == 2, score
fst_source.render(outfile="context_graph.svg") assert state.token == -1, state.token
score, state = context_graph.finalize(state)
assert score == 0, score
assert state.token == -1, state.token
score, state = context_graph.forward_one_step(context_graph.root, ord("S"))
assert score == 2, score
assert state.token == ord("S"), state.token
score, state = context_graph.forward_one_step(state, ord("H"))
assert score == 2, score
assert state.token == ord("H"), state.token
score, state = context_graph.forward_one_step(state, ord("D"))
assert score == -4, score
assert state.token == -1, state.token
score, state = context_graph.forward_one_step(context_graph.root, ord("D"))
assert score == 0, score
assert state.token == -1, state.token