mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
Implement Aho-Corasick context graph
This commit is contained in:
parent
1bc55376b6
commit
2e7e7875f5
@ -768,9 +768,6 @@ class Hypothesis:
|
||||
"""Return a string representation of self.ys"""
|
||||
return "_".join(map(str, self.ys))
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"ys: {'_'.join([str(i) for i in self.ys])}, log_prob: {float(self.log_prob):.2f}, state: {self.context_state}"
|
||||
|
||||
|
||||
class HypothesisList(object):
|
||||
def __init__(self, data: Optional[Dict[str, Hypothesis]] = None) -> None:
|
||||
@ -919,7 +916,6 @@ def modified_beam_search(
|
||||
encoder_out: torch.Tensor,
|
||||
encoder_out_lens: torch.Tensor,
|
||||
context_graph: Optional[ContextGraph] = None,
|
||||
num_context_history: int = 1,
|
||||
beam: int = 4,
|
||||
temperature: float = 1.0,
|
||||
return_timestamps: bool = False,
|
||||
@ -971,7 +967,7 @@ def modified_beam_search(
|
||||
Hypothesis(
|
||||
ys=[blank_id] * context_size,
|
||||
log_prob=torch.zeros(1, dtype=torch.float32, device=device),
|
||||
context_state=None if context_graph is None else ContextState(graph=context_graph, max_states=num_context_history),
|
||||
context_state=None if context_graph is None else context_graph.root,
|
||||
timestamp=[],
|
||||
)
|
||||
)
|
||||
@ -1056,12 +1052,15 @@ def modified_beam_search(
|
||||
new_token = topk_token_indexes[k]
|
||||
new_timestamp = hyp.timestamp[:]
|
||||
context_score = 0
|
||||
new_context_state = None if context_graph is None else hyp.context_state.clone()
|
||||
new_context_state = None if context_graph is None else hyp.context_state
|
||||
if new_token not in (blank_id, unk_id):
|
||||
new_ys.append(new_token)
|
||||
new_timestamp.append(t)
|
||||
if context_graph is not None:
|
||||
context_score, new_context_state = hyp.context_state.forward_one_step(new_token)
|
||||
(
|
||||
context_score,
|
||||
new_context_state,
|
||||
) = context_graph.forward_one_step(hyp.context_state, new_token)
|
||||
|
||||
new_log_prob = topk_log_probs[k] + context_score
|
||||
|
||||
@ -1081,7 +1080,9 @@ def modified_beam_search(
|
||||
finalized_B = [HypothesisList() for _ in range(len(B))]
|
||||
for i, hyps in enumerate(B):
|
||||
for hyp in list(hyps):
|
||||
context_score, new_context_state = hyp.context_state.finalize()
|
||||
context_score, new_context_state = context_graph.finalize(
|
||||
hyp.context_state
|
||||
)
|
||||
finalized_B[i].add(
|
||||
Hypothesis(
|
||||
ys=hyp.ys,
|
||||
|
||||
@ -362,21 +362,20 @@ def get_parser():
|
||||
"--context-score",
|
||||
type=float,
|
||||
default=2,
|
||||
help="",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--num-context-history",
|
||||
type=int,
|
||||
default=1,
|
||||
help="",
|
||||
help="""
|
||||
The bonus score of each token for the context biasing words/phrases.
|
||||
Used only when --decoding_method is modified_beam_search.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--context-file",
|
||||
type=str,
|
||||
default="",
|
||||
help="",
|
||||
help="""
|
||||
The path of the context biasing lists, one word/phrase each line
|
||||
Used only when --decoding_method is modified_beam_search.
|
||||
""",
|
||||
)
|
||||
|
||||
add_model_arguments(parser)
|
||||
@ -522,7 +521,6 @@ def decode_one_batch(
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
beam=params.beam_size,
|
||||
context_graph=context_graph,
|
||||
num_context_history=params.num_context_history,
|
||||
return_timestamps=True,
|
||||
)
|
||||
else:
|
||||
@ -579,7 +577,6 @@ def decode_one_batch(
|
||||
else:
|
||||
key = f"beam_size_{params.beam_size}"
|
||||
key += f"-context-score-{params.context_score}"
|
||||
key += f"-num-context-history-{params.num_context_history}"
|
||||
return {key: (hyps, timestamps)}
|
||||
|
||||
|
||||
@ -629,7 +626,7 @@ def decode_dataset(
|
||||
if params.decoding_method == "greedy_search":
|
||||
log_interval = 50
|
||||
else:
|
||||
log_interval = 1
|
||||
log_interval = 1
|
||||
|
||||
results = defaultdict(list)
|
||||
for batch_idx, batch in enumerate(dl):
|
||||
@ -785,7 +782,6 @@ def main():
|
||||
elif "beam_search" in params.decoding_method:
|
||||
params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
|
||||
params.suffix += f"-context-score-{params.context_score}"
|
||||
params.suffix += f"-num-context-history-{params.num_context_history}"
|
||||
else:
|
||||
params.suffix += f"-context-{params.context_size}"
|
||||
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
|
||||
@ -923,7 +919,7 @@ def main():
|
||||
for line in open(params.context_file).readlines():
|
||||
contexts.append(line.strip())
|
||||
context_graph = ContextGraph(params.context_score)
|
||||
context_graph.build_context_graph_bpe(contexts, sp)
|
||||
context_graph.build_context_graph(sp.encode(contexts))
|
||||
else:
|
||||
context_graph = None
|
||||
else:
|
||||
|
||||
@ -136,6 +136,7 @@ from beam_search import (
|
||||
from train import add_model_arguments, get_params, get_transducer_model
|
||||
|
||||
from icefall import ContextGraph, LmScorer, NgramLm
|
||||
from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
|
||||
from icefall.checkpoint import (
|
||||
average_checkpoints,
|
||||
average_checkpoints_with_averaged_model,
|
||||
@ -313,21 +314,20 @@ def get_parser():
|
||||
"--context-score",
|
||||
type=float,
|
||||
default=2,
|
||||
help="",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--num-context-history",
|
||||
type=int,
|
||||
default=1,
|
||||
help="",
|
||||
help="""
|
||||
The bonus score of each token for the context biasing words/phrases.
|
||||
Used only when --decoding_method is modified_beam_search.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--context-file",
|
||||
type=str,
|
||||
default="",
|
||||
help="",
|
||||
help="""
|
||||
The path of the context biasing lists, one word/phrase each line
|
||||
Used only when --decoding_method is modified_beam_search.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
@ -472,7 +472,6 @@ def decode_one_batch(
|
||||
beam=params.beam_size,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
context_graph=context_graph,
|
||||
num_context_history=params.num_context_history,
|
||||
)
|
||||
for i in range(encoder_out.size(0)):
|
||||
hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
|
||||
@ -535,7 +534,7 @@ def decode_one_batch(
|
||||
}
|
||||
else:
|
||||
return {
|
||||
f"beam_size_{params.beam_size}_context_score_{params.context_score}_num_context_history_{params.num_context_history}": hyps
|
||||
f"beam_size_{params.beam_size}_context_score_{params.context_score}": hyps
|
||||
}
|
||||
|
||||
|
||||
@ -685,7 +684,6 @@ def main():
|
||||
elif "beam_search" in params.decoding_method:
|
||||
params.suffix += f"-beam-{params.beam_size}"
|
||||
params.suffix += f"-context-score-{params.context_score}"
|
||||
params.suffix += f"-num-context-history-{params.num_context_history}"
|
||||
else:
|
||||
params.suffix += f"-context-{params.context_size}"
|
||||
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
|
||||
@ -715,11 +713,15 @@ def main():
|
||||
|
||||
logging.info(f"Device: {device}")
|
||||
|
||||
# import pdb; pdb.set_trace()
|
||||
lexicon = Lexicon(params.lang_dir)
|
||||
params.blank_id = lexicon.token_table["<blk>"]
|
||||
params.vocab_size = max(lexicon.tokens) + 1
|
||||
|
||||
graph_compiler = CharCtcTrainingGraphCompiler(
|
||||
lexicon=lexicon,
|
||||
device=device,
|
||||
)
|
||||
|
||||
if params.simulate_streaming:
|
||||
assert (
|
||||
params.causal_convolution
|
||||
@ -851,9 +853,9 @@ def main():
|
||||
if os.path.exists(params.context_file):
|
||||
contexts = []
|
||||
for line in open(params.context_file).readlines():
|
||||
contexts.append(line.strip())
|
||||
contexts.append(graph_compiler.texts_to_ids(line.strip()))
|
||||
context_graph = ContextGraph(params.context_score)
|
||||
context_graph.build_context_graph_char(contexts, lexicon.token_table)
|
||||
context_graph.build_context_graph(contexts)
|
||||
else:
|
||||
context_graph = None
|
||||
else:
|
||||
|
||||
@ -14,243 +14,209 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from heapq import heappush, heappop
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Tuple
|
||||
import argparse
|
||||
import k2
|
||||
import kaldifst
|
||||
import sentencepiece as spm
|
||||
|
||||
from icefall.utils import is_module_available
|
||||
|
||||
|
||||
class ContextGraph:
|
||||
def __init__(self, context_score: float = 1):
|
||||
self.context_score = context_score
|
||||
|
||||
def build_context_graph_char(
|
||||
self, contexts: List[str], token_table: k2.SymbolTable
|
||||
):
|
||||
"""Convert a list of texts to a list-of-list of token IDs.
|
||||
|
||||
Args:
|
||||
contexts:
|
||||
It is a list of strings.
|
||||
An example containing two strings is given below:
|
||||
|
||||
['你好中国', '北京欢迎您']
|
||||
token_table:
|
||||
The SymbolTable containing tokens and corresponding ids.
|
||||
|
||||
Returns:
|
||||
Return a list-of-list of token IDs.
|
||||
"""
|
||||
ids: List[List[int]] = []
|
||||
whitespace = re.compile(r"([ \t])")
|
||||
for text in contexts:
|
||||
text = re.sub(whitespace, "", text)
|
||||
sub_ids: List[int] = []
|
||||
skip = False
|
||||
for txt in text:
|
||||
if txt not in token_table:
|
||||
skip = True
|
||||
break
|
||||
sub_ids.append(token_table[txt])
|
||||
if skip:
|
||||
logging.warning(f"Skipping context {text}, as it has OOV char.")
|
||||
continue
|
||||
ids.append(sub_ids)
|
||||
self.build_context_graph(ids)
|
||||
|
||||
def build_context_graph_bpe(
|
||||
self, contexts: List[str], sp: spm.SentencePieceProcessor
|
||||
):
|
||||
contexts_bpe = sp.encode(contexts)
|
||||
self.build_context_graph(contexts_bpe)
|
||||
|
||||
def build_context_graph(self, token_ids: List[List[int]]):
|
||||
graph = kaldifst.StdVectorFst()
|
||||
start_state = (
|
||||
graph.add_state()
|
||||
) # 1st state will be state 0 (returned by add_state)
|
||||
assert start_state == 0, start_state
|
||||
graph.start = 0 # set the start state to 0
|
||||
graph.set_final(start_state, weight=kaldifst.TropicalWeight.one)
|
||||
|
||||
for tokens in token_ids:
|
||||
prev_state = start_state
|
||||
next_state = start_state
|
||||
backoff_score = 0
|
||||
for i in range(len(tokens)):
|
||||
score = self.context_score
|
||||
next_state = graph.add_state() if i < len(tokens) - 1 else start_state
|
||||
graph.add_arc(
|
||||
state=prev_state,
|
||||
arc=kaldifst.StdArc(
|
||||
ilabel=tokens[i],
|
||||
olabel=tokens[i],
|
||||
weight=score,
|
||||
nextstate=next_state,
|
||||
),
|
||||
)
|
||||
if i > 0:
|
||||
graph.add_arc(
|
||||
state=prev_state,
|
||||
arc=kaldifst.StdArc(
|
||||
ilabel=0,
|
||||
olabel=0,
|
||||
weight=-backoff_score,
|
||||
nextstate=start_state,
|
||||
),
|
||||
)
|
||||
prev_state = next_state
|
||||
backoff_score += score
|
||||
self.graph = kaldifst.determinize(graph)
|
||||
kaldifst.arcsort(self.graph)
|
||||
|
||||
def is_final_state(self, state_id: int) -> bool:
|
||||
return self.graph.final(state_id) == kaldifst.TropicalWeight.one
|
||||
|
||||
|
||||
def get_next_state(self, state_id: int, label: int) -> Tuple[int, float, bool]:
|
||||
arc_iter = kaldifst.ArcIterator(self.graph, state_id)
|
||||
num_arcs = self.graph.num_arcs(state_id)
|
||||
|
||||
# The LM is arc sorted by ilabel, so we use binary search below.
|
||||
left = 0
|
||||
right = num_arcs - 1
|
||||
while left <= right:
|
||||
mid = (left + right) // 2
|
||||
arc_iter.seek(mid)
|
||||
arc = arc_iter.value
|
||||
if arc.ilabel < label:
|
||||
left = mid + 1
|
||||
elif arc.ilabel > label:
|
||||
right = mid - 1
|
||||
else:
|
||||
return (arc.nextstate, arc.weight.value, True)
|
||||
|
||||
# Backoff to state 0 with the score on epsilon arc (ilabel == 0)
|
||||
arc_iter.seek(0)
|
||||
arc = arc_iter.value
|
||||
if arc.ilabel == 0:
|
||||
return (0, 0, False)
|
||||
else:
|
||||
return (0, arc.weight.value, False)
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
|
||||
class ContextState:
|
||||
def __init__(self, graph: ContextGraph, max_states: int):
|
||||
self.graph = graph
|
||||
self.max_states = max_states
|
||||
# [(total score, (score, state_id))]
|
||||
self.states: List[Tuple[float, Tuple[float, int]]] = []
|
||||
"""The state in ContextGraph"""
|
||||
|
||||
def __str__(self):
|
||||
return ";".join([str(state) for state in self.states])
|
||||
def __init__(self, token: int, score: float, total_score: float, is_end: bool):
|
||||
"""Create a ContextState.
|
||||
|
||||
def clone(self):
|
||||
new_context_state = ContextState(graph=self.graph, max_states=self.max_states)
|
||||
new_context_state.states = self.states[:]
|
||||
return new_context_state
|
||||
Args:
|
||||
token:
|
||||
The token id.
|
||||
score:
|
||||
The bonus for each token during decoding, which will hopefully
|
||||
boost the token up to survive beam search.
|
||||
total_score:
|
||||
The accumulated bonus from root of graph to current node, it will be
|
||||
used to calculate the score for fail arc.
|
||||
is_end:
|
||||
True if current token is the end of a context.
|
||||
"""
|
||||
self.token = token
|
||||
self.score = score
|
||||
self.total_score = total_score
|
||||
self.is_end = is_end
|
||||
self.next = {}
|
||||
self.fail = None
|
||||
|
||||
def finalize(self) -> float:
|
||||
new_context_state = ContextState(graph=self.graph, max_states=self.max_states)
|
||||
if len(self.states) == 0:
|
||||
return 0, new_context_state
|
||||
item = heappop(self.states)
|
||||
return item[0], new_context_state
|
||||
|
||||
def forward_one_step(self, label: int) -> float:
|
||||
states = self.states[:]
|
||||
new_states = []
|
||||
# expand current label from state state
|
||||
status = self.graph.get_next_state(0, label)
|
||||
if status[2]:
|
||||
heappush(new_states, (-status[1], (status[1], status[0])))
|
||||
class ContextGraph:
|
||||
"""The ContextGraph is modified from Aho-Corasick which is mainly
|
||||
a Trie with a fail arc for each node.
|
||||
See https://en.wikipedia.org/wiki/Aho%E2%80%93Corasick_algorithm for more details
|
||||
of Aho-Corasick algorithm.
|
||||
|
||||
A ContextGraph contains some words / phrases that we expect to boost their
|
||||
scores during decoding. If the substring of a decoded sequence matches the word / phrase
|
||||
in the ContextGraph, we will give the decoded sequence a bonus to make it survive
|
||||
beam search.
|
||||
"""
|
||||
|
||||
def __init__(self, context_score: float):
|
||||
"""Initialize a ContextGraph with the given ``context_score``.
|
||||
|
||||
A root node will be created (**NOTE:** the token of root is hardcoded to -1).
|
||||
|
||||
Args:
|
||||
context_score:
|
||||
The bonus score for each token(note: NOT for each word/phrase, it means longer
|
||||
word/phrase will have larger bonus score, they have to be matched though).
|
||||
"""
|
||||
self.context_score = context_score
|
||||
self.root = ContextState(token=-1, score=0, total_score=0, is_end=False)
|
||||
self.root.fail = self.root
|
||||
|
||||
def _fill_fail(self):
|
||||
"""This function fills the fail arc for each trie node, it can be computed
|
||||
in linear time by performing a breadth-first search starting from the root.
|
||||
See https://en.wikipedia.org/wiki/Aho%E2%80%93Corasick_algorithm for the
|
||||
details of the algorithm.
|
||||
"""
|
||||
queue = []
|
||||
for token, node in self.root.next.items():
|
||||
node.fail = self.root
|
||||
queue.append(node)
|
||||
while queue:
|
||||
current_node = queue.pop(0)
|
||||
current_fail = current_node.fail
|
||||
for token, node in current_node.next.items():
|
||||
fail = current_fail
|
||||
if token in current_fail.next:
|
||||
fail = current_fail.next[token]
|
||||
node.fail = fail
|
||||
queue.append(node)
|
||||
|
||||
def build_context_graph(self, token_ids: List[List[int]]):
|
||||
"""Build the ContextGraph from a list of token list.
|
||||
It first build a trie from the given token lists, then fill the fail arc
|
||||
for each trie node.
|
||||
|
||||
See https://en.wikipedia.org/wiki/Trie for how to build a trie.
|
||||
|
||||
Args:
|
||||
token_ids:
|
||||
The given token lists to build the ContextGraph, it is a list of token list,
|
||||
each token list contains the token ids for a word/phrase. The token id
|
||||
could be an id of a char (modeling with single Chinese char) or an id
|
||||
of a BPE (modeling with BPEs).
|
||||
"""
|
||||
for tokens in token_ids:
|
||||
node = self.root
|
||||
for i, token in enumerate(tokens):
|
||||
if token not in node.next:
|
||||
node.next[token] = ContextState(
|
||||
token=token,
|
||||
score=self.context_score,
|
||||
# The total score is the accumulated score from root to current node,
|
||||
# it will be used to calculate the score of fail arc later.
|
||||
total_score=node.total_score + self.context_score,
|
||||
is_end=i == len(tokens) - 1,
|
||||
)
|
||||
node = node.next[token]
|
||||
self._fill_fail()
|
||||
|
||||
def forward_one_step(
|
||||
self, state: ContextState, token: int
|
||||
) -> Tuple[float, ContextState]:
|
||||
"""Search the graph with given state and token.
|
||||
|
||||
Args:
|
||||
state:
|
||||
The given state (trie node) to start.
|
||||
token:
|
||||
The given token.
|
||||
|
||||
Returns:
|
||||
Return a tuple of score and next state.
|
||||
"""
|
||||
# token matched
|
||||
if token in state.next:
|
||||
node = state.next[token]
|
||||
score = node.score
|
||||
# if the matched node is the end of a word/phrase, we will start
|
||||
# from the root for next token.
|
||||
if node.is_end:
|
||||
node = self.root
|
||||
return (score, node)
|
||||
else:
|
||||
assert status[0] == 0 and status[2] == False, status
|
||||
# token not matched
|
||||
# We will trace along the fail arc until it matches the token or reaching
|
||||
# root of the graph.
|
||||
node = state.fail
|
||||
while token not in node.next:
|
||||
node = node.fail
|
||||
if node.token == -1: # root
|
||||
break
|
||||
|
||||
# the score we have added to the path till now
|
||||
prev_max_total_score = 0
|
||||
# expand previous states with given label
|
||||
while states:
|
||||
state = heappop(states)
|
||||
if -state[0] > prev_max_total_score:
|
||||
prev_max_total_score = -state[0]
|
||||
if token in node.next:
|
||||
node = node.next[token]
|
||||
# The score of the fail arc
|
||||
score = node.total_score - state.total_score
|
||||
if node.is_end:
|
||||
node = self.root
|
||||
return (score, node)
|
||||
|
||||
status = self.graph.get_next_state(state[1][1], label)
|
||||
def finalize(self, state: ContextState) -> Tuple[float, ContextState]:
|
||||
"""When reaching the end of the decoded sequence, we need to finalize
|
||||
the matching, the purpose is to subtract the added bonus score for the
|
||||
state that is not the end of a word/phrase.
|
||||
|
||||
if status[2]:
|
||||
heappush(new_states, (state[0] - status[1], (status[1], status[0])))
|
||||
else:
|
||||
pass
|
||||
# assert status == (0, state[0], False), status
|
||||
num_states_drop = (
|
||||
0
|
||||
if len(new_states) <= self.max_states
|
||||
else len(new_states) - self.max_states
|
||||
)
|
||||
|
||||
states = []
|
||||
if len(new_states) == 0:
|
||||
new_context_state = ContextState(graph=self.graph, max_states=self.max_states)
|
||||
return -prev_max_total_score, new_context_state
|
||||
|
||||
item = heappop(new_states)
|
||||
|
||||
# if one item match a context, clear all states (means start a new context
|
||||
# from next label), and return the score of current label
|
||||
if self.graph.is_final_state(item[1][1]):
|
||||
new_context_state = ContextState(graph=self.graph, max_states=self.max_states)
|
||||
return -item[0] - prev_max_total_score, new_context_state
|
||||
|
||||
max_total_score = -item[0]
|
||||
heappush(states, item)
|
||||
|
||||
while num_states_drop != 0:
|
||||
item = heappop(new_states)
|
||||
if self.graph.is_final_state(item[1][1]):
|
||||
new_context_state = ContextState(graph=self.graph, max_states=self.max_states)
|
||||
return -item[0] - prev_max_total_score, new_context_state
|
||||
num_states_drop -= 1
|
||||
|
||||
while new_states:
|
||||
item = heappop(new_states)
|
||||
if self.graph.is_final_state(item[1][1]):
|
||||
new_context_state = ContextState(graph=self.graph, max_states=self.max_states)
|
||||
return -item[0] - prev_max_total_score, new_context_state
|
||||
heappush(states, item)
|
||||
# no context matched, the matching may continue with previous prefix,
|
||||
# or change to another prefix.
|
||||
new_context_state = ContextState(graph=self.graph, max_states=self.max_states)
|
||||
new_context_state.states = states
|
||||
return max_total_score - prev_max_total_score, new_context_state
|
||||
Args:
|
||||
state:
|
||||
The given state(trie node).
|
||||
|
||||
Returns:
|
||||
Return a tuple of score and next state. If state is the end of a word/phrase
|
||||
the score is zero, otherwise the score is the score of a implicit fail arc
|
||||
to root. The next state is always root.
|
||||
"""
|
||||
# The score of the fail arc
|
||||
score = self.root.total_score - state.total_score
|
||||
if state.is_end:
|
||||
score = 0
|
||||
return (score, self.root)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--bpe_model",
|
||||
type=str,
|
||||
help="Path to bpe model",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
contexts_str = ["HE", "SHE", "HIS", "HERS"]
|
||||
contexts = []
|
||||
for s in contexts_str:
|
||||
contexts.append([ord(x) for x in s])
|
||||
|
||||
sp = spm.SentencePieceProcessor()
|
||||
sp.load(args.bpe_model)
|
||||
context_graph = ContextGraph(context_score=2)
|
||||
context_graph.build_context_graph(contexts)
|
||||
|
||||
contexts = ["LOVE CHINA", "HELLO WORLD", "LOVE WORLD"]
|
||||
context_graph = ContextGraph()
|
||||
context_graph.build_context_graph_bpe(contexts, sp)
|
||||
score, state = context_graph.forward_one_step(context_graph.root, ord("H"))
|
||||
assert score == 2, score
|
||||
assert state.token == ord("H"), state.token
|
||||
|
||||
if not is_module_available("graphviz"):
|
||||
raise ValueError("Please 'pip install graphviz' first.")
|
||||
import graphviz
|
||||
score, state = context_graph.forward_one_step(state, ord("I"))
|
||||
assert score == 2, score
|
||||
assert state.token == ord("I"), state.token
|
||||
|
||||
fst_dot = kaldifst.draw(context_graph.graph, acceptor=False, portrait=True)
|
||||
fst_source = graphviz.Source(fst_dot)
|
||||
fst_source.render(outfile="context_graph.svg")
|
||||
score, state = context_graph.forward_one_step(state, ord("S"))
|
||||
assert score == 2, score
|
||||
assert state.token == -1, state.token
|
||||
|
||||
score, state = context_graph.finalize(state)
|
||||
assert score == 0, score
|
||||
assert state.token == -1, state.token
|
||||
|
||||
score, state = context_graph.forward_one_step(context_graph.root, ord("S"))
|
||||
assert score == 2, score
|
||||
assert state.token == ord("S"), state.token
|
||||
|
||||
score, state = context_graph.forward_one_step(state, ord("H"))
|
||||
assert score == 2, score
|
||||
assert state.token == ord("H"), state.token
|
||||
|
||||
score, state = context_graph.forward_one_step(state, ord("D"))
|
||||
assert score == -4, score
|
||||
assert state.token == -1, state.token
|
||||
|
||||
score, state = context_graph.forward_one_step(context_graph.root, ord("D"))
|
||||
assert score == 0, score
|
||||
assert state.token == -1, state.token
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user