Implement Aho-Corasick context graph

commit 2e7e7875f5 (parent 1bc55376b6)
@@ -768,9 +768,6 @@ class Hypothesis:
         """Return a string representation of self.ys"""
         return "_".join(map(str, self.ys))
 
-    def __str__(self) -> str:
-        return f"ys: {'_'.join([str(i) for i in self.ys])}, log_prob: {float(self.log_prob):.2f}, state: {self.context_state}"
-
 
 class HypothesisList(object):
     def __init__(self, data: Optional[Dict[str, Hypothesis]] = None) -> None:
@@ -919,7 +916,6 @@ def modified_beam_search(
     encoder_out: torch.Tensor,
     encoder_out_lens: torch.Tensor,
     context_graph: Optional[ContextGraph] = None,
-    num_context_history: int = 1,
     beam: int = 4,
     temperature: float = 1.0,
     return_timestamps: bool = False,
@@ -971,7 +967,7 @@ def modified_beam_search(
             Hypothesis(
                 ys=[blank_id] * context_size,
                 log_prob=torch.zeros(1, dtype=torch.float32, device=device),
-                context_state=None if context_graph is None else ContextState(graph=context_graph, max_states=num_context_history),
+                context_state=None if context_graph is None else context_graph.root,
                 timestamp=[],
             )
         )
@@ -1056,12 +1052,15 @@ def modified_beam_search(
                 new_token = topk_token_indexes[k]
                 new_timestamp = hyp.timestamp[:]
                 context_score = 0
-                new_context_state = None if context_graph is None else hyp.context_state.clone()
+                new_context_state = None if context_graph is None else hyp.context_state
                 if new_token not in (blank_id, unk_id):
                     new_ys.append(new_token)
                     new_timestamp.append(t)
                     if context_graph is not None:
-                        context_score, new_context_state = hyp.context_state.forward_one_step(new_token)
+                        (
+                            context_score,
+                            new_context_state,
+                        ) = context_graph.forward_one_step(hyp.context_state, new_token)
 
                 new_log_prob = topk_log_probs[k] + context_score
 
@@ -1081,7 +1080,9 @@ def modified_beam_search(
         finalized_B = [HypothesisList() for _ in range(len(B))]
         for i, hyps in enumerate(B):
             for hyp in list(hyps):
-                context_score, new_context_state = hyp.context_state.finalize()
+                context_score, new_context_state = context_graph.finalize(
+                    hyp.context_state
+                )
                 finalized_B[i].add(
                     Hypothesis(
                         ys=hyp.ys,
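Taken together, the hunks above move all context-biasing state out of Hypothesis and into the graph itself: a hypothesis now only carries a pointer to a trie node (initialized to context_graph.root), and ContextGraph computes transitions and scores. A minimal sketch of the contract, assuming the forward_one_step/finalize API introduced by this commit; the loop and token ids are illustrative, not the actual modified_beam_search:

    # `graph` is a built ContextGraph; every hypothesis starts at the root.
    state = graph.root
    total_bonus = 0.0
    for token in [5, 17, 23]:  # hypothetical non-blank token ids
        bonus, state = graph.forward_one_step(state, token)
        total_bonus += bonus  # folded into the hypothesis log-prob
    # At the end of the utterance, cancel the bonus of any partial match.
    bonus, state = graph.finalize(state)
    total_bonus += bonus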
@@ -362,21 +362,20 @@ def get_parser():
         "--context-score",
         type=float,
         default=2,
-        help="",
-    )
-
-    parser.add_argument(
-        "--num-context-history",
-        type=int,
-        default=1,
-        help="",
+        help="""
+        The bonus score of each token for the context biasing words/phrases.
+        Used only when --decoding_method is modified_beam_search.
+        """,
     )
 
     parser.add_argument(
         "--context-file",
         type=str,
         default="",
-        help="",
+        help="""
+        The path of the context biasing lists, one word/phrase each line
+        Used only when --decoding_method is modified_beam_search.
+        """,
     )
 
     add_model_arguments(parser)
@@ -522,7 +521,6 @@ def decode_one_batch(
             encoder_out_lens=encoder_out_lens,
             beam=params.beam_size,
             context_graph=context_graph,
-            num_context_history=params.num_context_history,
             return_timestamps=True,
         )
     else:
@@ -579,7 +577,6 @@ def decode_one_batch(
     else:
         key = f"beam_size_{params.beam_size}"
         key += f"-context-score-{params.context_score}"
-        key += f"-num-context-history-{params.num_context_history}"
         return {key: (hyps, timestamps)}
 
 
@@ -785,7 +782,6 @@ def main():
     elif "beam_search" in params.decoding_method:
         params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
         params.suffix += f"-context-score-{params.context_score}"
-        params.suffix += f"-num-context-history-{params.num_context_history}"
     else:
         params.suffix += f"-context-{params.context_size}"
         params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
@@ -923,7 +919,7 @@ def main():
             for line in open(params.context_file).readlines():
                 contexts.append(line.strip())
             context_graph = ContextGraph(params.context_score)
-            context_graph.build_context_graph_bpe(contexts, sp)
+            context_graph.build_context_graph(sp.encode(contexts))
         else:
             context_graph = None
     else:
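In this recipe the biasing phrases are now tokenized with sentencepiece up front: sp.encode on a list of strings returns a list of token-id lists, which is exactly the List[List[int]] that build_context_graph expects. A minimal sketch, assuming a hypothetical bpe.model path and made-up phrases:

    import sentencepiece as spm

    sp = spm.SentencePieceProcessor()
    sp.load("bpe.model")  # hypothetical path

    contexts = ["HELLO WORLD", "LOVE CHINA"]  # example phrases
    token_ids = sp.encode(contexts)  # -> List[List[int]]

    context_graph = ContextGraph(context_score=2.0)
    context_graph.build_context_graph(token_ids)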
@@ -136,6 +136,7 @@ from beam_search import (
 from train import add_model_arguments, get_params, get_transducer_model
 
 from icefall import ContextGraph, LmScorer, NgramLm
+from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
 from icefall.checkpoint import (
     average_checkpoints,
     average_checkpoints_with_averaged_model,
@@ -313,21 +314,20 @@ def get_parser():
         "--context-score",
         type=float,
         default=2,
-        help="",
-    )
-
-    parser.add_argument(
-        "--num-context-history",
-        type=int,
-        default=1,
-        help="",
+        help="""
+        The bonus score of each token for the context biasing words/phrases.
+        Used only when --decoding_method is modified_beam_search.
+        """,
     )
 
     parser.add_argument(
         "--context-file",
         type=str,
         default="",
-        help="",
+        help="""
+        The path of the context biasing lists, one word/phrase each line
+        Used only when --decoding_method is modified_beam_search.
+        """,
     )
 
     parser.add_argument(
@@ -472,7 +472,6 @@ def decode_one_batch(
             beam=params.beam_size,
             encoder_out_lens=encoder_out_lens,
             context_graph=context_graph,
-            num_context_history=params.num_context_history,
         )
         for i in range(encoder_out.size(0)):
             hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
@@ -535,7 +534,7 @@ def decode_one_batch(
         }
     else:
         return {
-            f"beam_size_{params.beam_size}_context_score_{params.context_score}_num_context_history_{params.num_context_history}": hyps
+            f"beam_size_{params.beam_size}_context_score_{params.context_score}": hyps
         }
 
 
@@ -685,7 +684,6 @@ def main():
     elif "beam_search" in params.decoding_method:
         params.suffix += f"-beam-{params.beam_size}"
         params.suffix += f"-context-score-{params.context_score}"
-        params.suffix += f"-num-context-history-{params.num_context_history}"
     else:
         params.suffix += f"-context-{params.context_size}"
         params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
@@ -715,11 +713,15 @@ def main():
 
     logging.info(f"Device: {device}")
 
-    # import pdb; pdb.set_trace()
     lexicon = Lexicon(params.lang_dir)
     params.blank_id = lexicon.token_table["<blk>"]
     params.vocab_size = max(lexicon.tokens) + 1
 
+    graph_compiler = CharCtcTrainingGraphCompiler(
+        lexicon=lexicon,
+        device=device,
+    )
+
     if params.simulate_streaming:
         assert (
             params.causal_convolution
@@ -851,9 +853,9 @@ def main():
         if os.path.exists(params.context_file):
             contexts = []
             for line in open(params.context_file).readlines():
-                contexts.append(line.strip())
+                contexts.append(graph_compiler.texts_to_ids(line.strip()))
             context_graph = ContextGraph(params.context_score)
-            context_graph.build_context_graph_char(contexts, lexicon.token_table)
+            context_graph.build_context_graph(contexts)
         else:
            context_graph = None
     else:
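The char-based recipe reaches the same List[List[int]] shape through the lexicon instead of a BPE model. A sketch of the equivalent manual mapping, mirroring the removed build_context_graph_char helper; token_table here stands for the k2.SymbolTable from the lexicon, and the phrases are examples:

    # One token id per character; OOV phrases should be skipped in practice.
    contexts = []
    for phrase in ["你好中国", "北京欢迎您"]:  # example phrases
        contexts.append([token_table[ch] for ch in phrase])

    context_graph = ContextGraph(params.context_score)
    context_graph.build_context_graph(contexts)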
@@ -14,243 +14,209 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from heapq import heappush, heappop
-import re
-from dataclasses import dataclass
-from typing import List, Tuple
-import argparse
-import k2
-import kaldifst
-import sentencepiece as spm
-
-from icefall.utils import is_module_available
+from typing import Dict, List, Tuple
 
 
-class ContextGraph:
-    def __init__(self, context_score: float = 1):
-        self.context_score = context_score
-
-    def build_context_graph_char(
-        self, contexts: List[str], token_table: k2.SymbolTable
-    ):
-        """Convert a list of texts to a list-of-list of token IDs.
-
-        Args:
-          contexts:
-            It is a list of strings.
-            An example containing two strings is given below:
-
-                ['你好中国', '北京欢迎您']
-          token_table:
-            The SymbolTable containing tokens and corresponding ids.
-
-        Returns:
-          Return a list-of-list of token IDs.
-        """
-        ids: List[List[int]] = []
-        whitespace = re.compile(r"([ \t])")
-        for text in contexts:
-            text = re.sub(whitespace, "", text)
-            sub_ids: List[int] = []
-            skip = False
-            for txt in text:
-                if txt not in token_table:
-                    skip = True
-                    break
-                sub_ids.append(token_table[txt])
-            if skip:
-                logging.warning(f"Skipping context {text}, as it has OOV char.")
-                continue
-            ids.append(sub_ids)
-        self.build_context_graph(ids)
-
-    def build_context_graph_bpe(
-        self, contexts: List[str], sp: spm.SentencePieceProcessor
-    ):
-        contexts_bpe = sp.encode(contexts)
-        self.build_context_graph(contexts_bpe)
-
-    def build_context_graph(self, token_ids: List[List[int]]):
-        graph = kaldifst.StdVectorFst()
-        start_state = (
-            graph.add_state()
-        )  # 1st state will be state 0 (returned by add_state)
-        assert start_state == 0, start_state
-        graph.start = 0  # set the start state to 0
-        graph.set_final(start_state, weight=kaldifst.TropicalWeight.one)
-
-        for tokens in token_ids:
-            prev_state = start_state
-            next_state = start_state
-            backoff_score = 0
-            for i in range(len(tokens)):
-                score = self.context_score
-                next_state = graph.add_state() if i < len(tokens) - 1 else start_state
-                graph.add_arc(
-                    state=prev_state,
-                    arc=kaldifst.StdArc(
-                        ilabel=tokens[i],
-                        olabel=tokens[i],
-                        weight=score,
-                        nextstate=next_state,
-                    ),
-                )
-                if i > 0:
-                    graph.add_arc(
-                        state=prev_state,
-                        arc=kaldifst.StdArc(
-                            ilabel=0,
-                            olabel=0,
-                            weight=-backoff_score,
-                            nextstate=start_state,
-                        ),
-                    )
-                prev_state = next_state
-                backoff_score += score
-        self.graph = kaldifst.determinize(graph)
-        kaldifst.arcsort(self.graph)
-
-    def is_final_state(self, state_id: int) -> bool:
-        return self.graph.final(state_id) == kaldifst.TropicalWeight.one
-
-    def get_next_state(self, state_id: int, label: int) -> Tuple[int, float, bool]:
-        arc_iter = kaldifst.ArcIterator(self.graph, state_id)
-        num_arcs = self.graph.num_arcs(state_id)
-
-        # The LM is arc sorted by ilabel, so we use binary search below.
-        left = 0
-        right = num_arcs - 1
-        while left <= right:
-            mid = (left + right) // 2
-            arc_iter.seek(mid)
-            arc = arc_iter.value
-            if arc.ilabel < label:
-                left = mid + 1
-            elif arc.ilabel > label:
-                right = mid - 1
-            else:
-                return (arc.nextstate, arc.weight.value, True)
-
-        # Backoff to state 0 with the score on epsilon arc (ilabel == 0)
-        arc_iter.seek(0)
-        arc = arc_iter.value
-        if arc.ilabel == 0:
-            return (0, 0, False)
-        else:
-            return (0, arc.weight.value, False)
-
-
 class ContextState:
-    def __init__(self, graph: ContextGraph, max_states: int):
-        self.graph = graph
-        self.max_states = max_states
-        # [(total score, (score, state_id))]
-        self.states: List[Tuple[float, Tuple[float, int]]] = []
-
-    def __str__(self):
-        return ";".join([str(state) for state in self.states])
-
-    def clone(self):
-        new_context_state = ContextState(graph=self.graph, max_states=self.max_states)
-        new_context_state.states = self.states[:]
-        return new_context_state
-
-    def finalize(self) -> float:
-        new_context_state = ContextState(graph=self.graph, max_states=self.max_states)
-        if len(self.states) == 0:
-            return 0, new_context_state
-        item = heappop(self.states)
-        return item[0], new_context_state
-
-    def forward_one_step(self, label: int) -> float:
-        states = self.states[:]
-        new_states = []
-        # expand current label from state state
-        status = self.graph.get_next_state(0, label)
-        if status[2]:
-            heappush(new_states, (-status[1], (status[1], status[0])))
-        else:
-            assert status[0] == 0 and status[2] == False, status
-
-        # the score we have added to the path till now
-        prev_max_total_score = 0
-        # expand previous states with given label
-        while states:
-            state = heappop(states)
-            if -state[0] > prev_max_total_score:
-                prev_max_total_score = -state[0]
-
-            status = self.graph.get_next_state(state[1][1], label)
-
-            if status[2]:
-                heappush(new_states, (state[0] - status[1], (status[1], status[0])))
-            else:
-                pass
-                # assert status == (0, state[0], False), status
-        num_states_drop = (
-            0
-            if len(new_states) <= self.max_states
-            else len(new_states) - self.max_states
-        )
-
-        states = []
-        if len(new_states) == 0:
-            new_context_state = ContextState(graph=self.graph, max_states=self.max_states)
-            return -prev_max_total_score, new_context_state
-
-        item = heappop(new_states)
-
-        # if one item match a context, clear all states (means start a new context
-        # from next label), and return the score of current label
-        if self.graph.is_final_state(item[1][1]):
-            new_context_state = ContextState(graph=self.graph, max_states=self.max_states)
-            return -item[0] - prev_max_total_score, new_context_state
-
-        max_total_score = -item[0]
-        heappush(states, item)
-
-        while num_states_drop != 0:
-            item = heappop(new_states)
-            if self.graph.is_final_state(item[1][1]):
-                new_context_state = ContextState(graph=self.graph, max_states=self.max_states)
-                return -item[0] - prev_max_total_score, new_context_state
-            num_states_drop -= 1
-
-        while new_states:
-            item = heappop(new_states)
-            if self.graph.is_final_state(item[1][1]):
-                new_context_state = ContextState(graph=self.graph, max_states=self.max_states)
-                return -item[0] - prev_max_total_score, new_context_state
-            heappush(states, item)
-        # no context matched, the matching may continue with previous prefix,
-        # or change to another prefix.
-        new_context_state = ContextState(graph=self.graph, max_states=self.max_states)
-        new_context_state.states = states
-        return max_total_score - prev_max_total_score, new_context_state
+    """The state in ContextGraph"""
+
+    def __init__(self, token: int, score: float, total_score: float, is_end: bool):
+        """Create a ContextState.
+
+        Args:
+          token:
+            The token id.
+          score:
+            The bonus for each token during decoding, which will hopefully
+            boost the token up to survive beam search.
+          total_score:
+            The accumulated bonus from the root of the graph to the current
+            node; it is used to calculate the score of the fail arc.
+          is_end:
+            True if the current token is the end of a context.
+        """
+        self.token = token
+        self.score = score
+        self.total_score = total_score
+        self.is_end = is_end
+        self.next = {}
+        self.fail = None
+
+
+class ContextGraph:
+    """The ContextGraph is modified from Aho-Corasick; it is mainly
+    a Trie with a fail arc for each node.
+    See https://en.wikipedia.org/wiki/Aho%E2%80%93Corasick_algorithm for more
+    details of the Aho-Corasick algorithm.
+
+    A ContextGraph contains the words / phrases that we expect to boost
+    during decoding. If a substring of a decoded sequence matches a word / phrase
+    in the ContextGraph, we give the decoded sequence a bonus to make it survive
+    beam search.
+    """
+
+    def __init__(self, context_score: float):
+        """Initialize a ContextGraph with the given ``context_score``.
+
+        A root node will be created (**NOTE:** the token of root is hardcoded to -1).
+
+        Args:
+          context_score:
+            The bonus score for each token (note: NOT for each word/phrase; a longer
+            word/phrase has a larger bonus score, but it has to be matched in full).
+        """
+        self.context_score = context_score
+        self.root = ContextState(token=-1, score=0, total_score=0, is_end=False)
+        self.root.fail = self.root
+
+    def _fill_fail(self):
+        """This function fills the fail arc for each trie node; it can be computed
+        in linear time by performing a breadth-first search starting from the root.
+        See https://en.wikipedia.org/wiki/Aho%E2%80%93Corasick_algorithm for the
+        details of the algorithm.
+        """
+        queue = []
+        for token, node in self.root.next.items():
+            node.fail = self.root
+            queue.append(node)
+        while queue:
+            current_node = queue.pop(0)
+            current_fail = current_node.fail
+            for token, node in current_node.next.items():
+                fail = current_fail
+                if token in current_fail.next:
+                    fail = current_fail.next[token]
+                node.fail = fail
+                queue.append(node)
+
+    def build_context_graph(self, token_ids: List[List[int]]):
+        """Build the ContextGraph from a list of token lists.
+        It first builds a trie from the given token lists, then fills the fail arc
+        for each trie node.
+
+        See https://en.wikipedia.org/wiki/Trie for how to build a trie.
+
+        Args:
+          token_ids:
+            The token lists used to build the ContextGraph; it is a list of token
+            lists, each of which contains the token ids for a word/phrase. A token
+            id could be the id of a char (modeling with single Chinese chars) or
+            the id of a BPE piece (modeling with BPEs).
+        """
+        for tokens in token_ids:
+            node = self.root
+            for i, token in enumerate(tokens):
+                if token not in node.next:
+                    node.next[token] = ContextState(
+                        token=token,
+                        score=self.context_score,
+                        # The total score is the accumulated score from root to current node,
+                        # it will be used to calculate the score of fail arc later.
+                        total_score=node.total_score + self.context_score,
+                        is_end=i == len(tokens) - 1,
+                    )
+                node = node.next[token]
+        self._fill_fail()
+
+    def forward_one_step(
+        self, state: ContextState, token: int
+    ) -> Tuple[float, ContextState]:
+        """Search the graph with the given state and token.
+
+        Args:
+          state:
+            The given state (trie node) to start from.
+          token:
+            The given token.
+
+        Returns:
+          Return a tuple of score and next state.
+        """
+        # token matched
+        if token in state.next:
+            node = state.next[token]
+            score = node.score
+            # if the matched node is the end of a word/phrase, we will start
+            # from the root for the next token.
+            if node.is_end:
+                node = self.root
+            return (score, node)
+        else:
+            # token not matched
+            # We trace along the fail arcs until we match the token or reach
+            # the root of the graph.
+            node = state.fail
+            while token not in node.next:
+                node = node.fail
+                if node.token == -1:  # root
+                    break
+
+            if token in node.next:
+                node = node.next[token]
+
+            # The score of the fail arc
+            score = node.total_score - state.total_score
+            if node.is_end:
+                node = self.root
+            return (score, node)
+
+    def finalize(self, state: ContextState) -> Tuple[float, ContextState]:
+        """When reaching the end of the decoded sequence, we need to finalize
+        the matching; the purpose is to subtract the bonus already added for a
+        state that is not the end of a word/phrase.
+
+        Args:
+          state:
+            The given state (trie node).
+
+        Returns:
+          Return a tuple of score and next state. If the state is the end of a
+          word/phrase the score is zero; otherwise it is the score of an implicit
+          fail arc to the root. The next state is always the root.
+        """
+        # The score of the fail arc
+        score = self.root.total_score - state.total_score
+        if state.is_end:
+            score = 0
+        return (score, self.root)
 
 
 if __name__ == "__main__":
-    parser = argparse.ArgumentParser()
-    parser.add_argument(
-        "--bpe_model",
-        type=str,
-        help="Path to bpe model",
-    )
-    args = parser.parse_args()
-
-    sp = spm.SentencePieceProcessor()
-    sp.load(args.bpe_model)
-
-    contexts = ["LOVE CHINA", "HELLO WORLD", "LOVE WORLD"]
-    context_graph = ContextGraph()
-    context_graph.build_context_graph_bpe(contexts, sp)
-
-    if not is_module_available("graphviz"):
-        raise ValueError("Please 'pip install graphviz' first.")
-    import graphviz
-
-    fst_dot = kaldifst.draw(context_graph.graph, acceptor=False, portrait=True)
-    fst_source = graphviz.Source(fst_dot)
-    fst_source.render(outfile="context_graph.svg")
+    contexts_str = ["HE", "SHE", "HIS", "HERS"]
+    contexts = []
+    for s in contexts_str:
+        contexts.append([ord(x) for x in s])
+
+    context_graph = ContextGraph(context_score=2)
+    context_graph.build_context_graph(contexts)
+
+    score, state = context_graph.forward_one_step(context_graph.root, ord("H"))
+    assert score == 2, score
+    assert state.token == ord("H"), state.token
+
+    score, state = context_graph.forward_one_step(state, ord("I"))
+    assert score == 2, score
+    assert state.token == ord("I"), state.token
+
+    score, state = context_graph.forward_one_step(state, ord("S"))
+    assert score == 2, score
+    assert state.token == -1, state.token
+
+    score, state = context_graph.finalize(state)
+    assert score == 0, score
+    assert state.token == -1, state.token
+
+    score, state = context_graph.forward_one_step(context_graph.root, ord("S"))
+    assert score == 2, score
+    assert state.token == ord("S"), state.token
+
+    score, state = context_graph.forward_one_step(state, ord("H"))
+    assert score == 2, score
+    assert state.token == ord("H"), state.token
+
+    score, state = context_graph.forward_one_step(state, ord("D"))
+    assert score == -4, score
+    assert state.token == -1, state.token
+
+    score, state = context_graph.forward_one_step(context_graph.root, ord("D"))
+    assert score == 0, score
+    assert state.token == -1, state.token
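The assertions above pin down the fail-arc arithmetic: each matched token adds context_score (2 here), and total_score on a node is the bonus accumulated along its prefix, so a failed continuation returns next_node.total_score - current_node.total_score, clawing back everything paid for a prefix that never completed. A standalone repro of the "S", "H", "D" case, using the same toy contexts as the tests:

    graph = ContextGraph(context_score=2)
    graph.build_context_graph(
        [[ord(c) for c in w] for w in ["HE", "SHE", "HIS", "HERS"]]
    )

    state = graph.root
    for ch in "SH":  # partial match: +2 per token, now on the "SH" node
        bonus, state = graph.forward_one_step(state, ord(ch))
    bonus, state = graph.forward_one_step(state, ord("D"))  # dead end
    print(bonus)  # -4, i.e. 0 - 4: cancels the bonus paid for the prefix "SH"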