add custom score for each hotword

This commit is contained in:
pkufool 2023-11-10 22:24:57 +08:00
parent 6d275ddf9f
commit 8c5f5795d4

View File

@ -133,7 +133,7 @@ class ContextGraph:
node.output_score += 0 if output is None else output.output_score
queue.append(node)
def build(self, token_ids: List[List[int]]):
def build(self, token_ids: List[Tuple[List[int], float]]):
"""Build the ContextGraph from a list of token list.
It first build a trie from the given token lists, then fill the fail arc
for each trie node.
@ -147,21 +147,31 @@ class ContextGraph:
could be an id of a char (modeling with single Chinese char) or an id
of a BPE (modeling with BPEs).
"""
for tokens in token_ids:
for (tokens, score) in token_ids:
node = self.root
context_score = self.context_score if score == 0.0 else round(score / len(tokens), 2)
for i, token in enumerate(tokens):
node_next = {}
if token not in node.next:
self.num_nodes += 1
node_id = self.num_nodes
token_score = context_score
is_end = i == len(tokens) - 1
node_score = node.node_score + self.context_score
node.next[token] = ContextState(
id=self.num_nodes,
token=token,
token_score=self.context_score,
node_score=node_score,
output_score=node_score if is_end else 0,
is_end=is_end,
)
else:
token_score = max(context_score, node.next[token].token_score)
node_id = node.next[token].id
node_next = node.next[token].next
is_end = i == len(tokens) - 1 or node.next[token].is_end
node_score = node.node_score + token_score
node.next[token] = ContextState(
id=node_id,
token=token,
token_score=token_score,
node_score=node_score,
output_score=node_score if is_end else 0,
is_end=is_end,
)
node.next[token].next = node_next
node = node.next[token]
self._fill_fail_output()
@ -343,7 +353,7 @@ class ContextGraph:
return dot
if __name__ == "__main__":
def _test(queries, score):
contexts_str = [
"S",
"HE",
@ -355,9 +365,11 @@ if __name__ == "__main__":
"THIS",
"THEM",
]
# test default score (1)
contexts = []
for s in contexts_str:
contexts.append([ord(x) for x in s])
contexts.append(([ord(x) for x in s], score))
context_graph = ContextGraph(context_score=1)
context_graph.build(contexts)
@ -369,10 +381,28 @@ if __name__ == "__main__":
context_graph.draw(
title="Graph for: " + " / ".join(contexts_str),
filename="context_graph.pdf",
filename=f"context_graph_{score}.pdf",
symbol_table=symbol_table,
)
for query, expected_score in queries.items():
total_scores = 0
state = context_graph.root
for q in query:
score, state = context_graph.forward_one_step(state, ord(q))
total_scores += score
score, state = context_graph.finalize(state)
assert state.token == -1, state.token
total_scores += score
assert round(total_scores, 2) == expected_score, (
total_scores,
expected_score,
query,
)
if __name__ == "__main__":
# test default score
queries = {
"HEHERSHE": 14, # "HE", "HE", "HERS", "S", "SHE", "HE"
"HERSHE": 12, # "HE", "HERS", "S", "SHE", "HE"
@ -384,17 +414,27 @@ if __name__ == "__main__":
"DHRHISQ": 4, # "HIS", "S"
"THEN": 2, # "HE"
}
for query, expected_score in queries.items():
total_scores = 0
state = context_graph.root
for q in query:
score, state = context_graph.forward_one_step(state, ord(q))
total_scores += score
score, state = context_graph.finalize(state)
assert state.token == -1, state.token
total_scores += score
assert total_scores == expected_score, (
total_scores,
expected_score,
query,
)
_test(queries, 0)
# test custom score (5)
# S : 5
# HE : 5 (2.5 + 2.5)
# SHE : 8.34 (5 + 1.67 + 1.67)
# SHELL : 10.34 (5 + 1.67 + 1.67 + 1 + 1)
# HIS : 5.84 (2.5 + 1.67 + 1.67)
# HERS : 7.5 (2.5 + 2.5 + 1.25 + 1.25)
# HELLO : 8 (2.5 + 2.5 + 1 + 1 + 1)
# THIS : 5 (1.25 + 1.25 + 1.25 + 1.25)
queries = {
"HEHERSHE": 35.84, # "HE", "HE", "HERS", "S", "SHE", "HE"
"HERSHE": 30.84, # "HE", "HERS", "S", "SHE", "HE"
"HISHE": 24.18, # "HIS", "S", "SHE", "HE"
"SHED": 18.34, # "S", "SHE", "HE"
"SHELF": 18.34, # "S", "SHE", "HE"
"HELL": 5, # "HE"
"HELLO": 13, # "HE", "HELLO"
"DHRHISQ": 10.84, # "HIS", "S"
"THEN": 5, # "HE"
}
_test(queries, 5)