add custom score for each hotword

This commit is contained in:
pkufool 2023-11-10 22:24:57 +08:00
parent 6d275ddf9f
commit 8c5f5795d4

View File

@ -133,7 +133,7 @@ class ContextGraph:
node.output_score += 0 if output is None else output.output_score node.output_score += 0 if output is None else output.output_score
queue.append(node) queue.append(node)
def build(self, token_ids: List[List[int]]): def build(self, token_ids: List[Tuple[List[int], float]]):
"""Build the ContextGraph from a list of token list. """Build the ContextGraph from a list of token list.
It first builds a trie from the given token lists, then fills the fail arc It first builds a trie from the given token lists, then fills the fail arc
for each trie node. for each trie node.
@ -147,21 +147,31 @@ class ContextGraph:
could be an id of a char (modeling with single Chinese char) or an id could be an id of a char (modeling with single Chinese char) or an id
of a BPE (modeling with BPEs). of a BPE (modeling with BPEs).
""" """
for tokens in token_ids: for (tokens, score) in token_ids:
node = self.root node = self.root
context_score = self.context_score if score == 0.0 else round(score / len(tokens), 2)
for i, token in enumerate(tokens): for i, token in enumerate(tokens):
node_next = {}
if token not in node.next: if token not in node.next:
self.num_nodes += 1 self.num_nodes += 1
node_id = self.num_nodes
token_score = context_score
is_end = i == len(tokens) - 1 is_end = i == len(tokens) - 1
node_score = node.node_score + self.context_score else:
node.next[token] = ContextState( token_score = max(context_score, node.next[token].token_score)
id=self.num_nodes, node_id = node.next[token].id
token=token, node_next = node.next[token].next
token_score=self.context_score, is_end = i == len(tokens) - 1 or node.next[token].is_end
node_score=node_score, node_score = node.node_score + token_score
output_score=node_score if is_end else 0, node.next[token] = ContextState(
is_end=is_end, id=node_id,
) token=token,
token_score=token_score,
node_score=node_score,
output_score=node_score if is_end else 0,
is_end=is_end,
)
node.next[token].next = node_next
node = node.next[token] node = node.next[token]
self._fill_fail_output() self._fill_fail_output()
@ -343,7 +353,7 @@ class ContextGraph:
return dot return dot
if __name__ == "__main__": def _test(queries, score):
contexts_str = [ contexts_str = [
"S", "S",
"HE", "HE",
@ -355,9 +365,11 @@ if __name__ == "__main__":
"THIS", "THIS",
"THEM", "THEM",
] ]
# test default score (1)
contexts = [] contexts = []
for s in contexts_str: for s in contexts_str:
contexts.append([ord(x) for x in s]) contexts.append(([ord(x) for x in s], score))
context_graph = ContextGraph(context_score=1) context_graph = ContextGraph(context_score=1)
context_graph.build(contexts) context_graph.build(contexts)
@ -369,10 +381,28 @@ if __name__ == "__main__":
context_graph.draw( context_graph.draw(
title="Graph for: " + " / ".join(contexts_str), title="Graph for: " + " / ".join(contexts_str),
filename="context_graph.pdf", filename=f"context_graph_{score}.pdf",
symbol_table=symbol_table, symbol_table=symbol_table,
) )
for query, expected_score in queries.items():
total_scores = 0
state = context_graph.root
for q in query:
score, state = context_graph.forward_one_step(state, ord(q))
total_scores += score
score, state = context_graph.finalize(state)
assert state.token == -1, state.token
total_scores += score
assert round(total_scores, 2) == expected_score, (
total_scores,
expected_score,
query,
)
if __name__ == "__main__":
# test default score
queries = { queries = {
"HEHERSHE": 14, # "HE", "HE", "HERS", "S", "SHE", "HE" "HEHERSHE": 14, # "HE", "HE", "HERS", "S", "SHE", "HE"
"HERSHE": 12, # "HE", "HERS", "S", "SHE", "HE" "HERSHE": 12, # "HE", "HERS", "S", "SHE", "HE"
@ -384,17 +414,27 @@ if __name__ == "__main__":
"DHRHISQ": 4, # "HIS", "S" "DHRHISQ": 4, # "HIS", "S"
"THEN": 2, # "HE" "THEN": 2, # "HE"
} }
for query, expected_score in queries.items(): _test(queries, 0)
total_scores = 0
state = context_graph.root # test custom score (5)
for q in query: # S : 5
score, state = context_graph.forward_one_step(state, ord(q)) # HE : 5 (2.5 + 2.5)
total_scores += score # SHE : 8.34 (5 + 1.67 + 1.67)
score, state = context_graph.finalize(state) # SHELL : 10.34 (5 + 1.67 + 1.67 + 1 + 1)
assert state.token == -1, state.token # HIS : 5.84 (2.5 + 1.67 + 1.67)
total_scores += score # HERS : 7.5 (2.5 + 2.5 + 1.25 + 1.25)
assert total_scores == expected_score, ( # HELLO : 8 (2.5 + 2.5 + 1 + 1 + 1)
total_scores, # THIS : 5 (1.25 + 1.25 + 1.25 + 1.25)
expected_score, queries = {
query, "HEHERSHE": 35.84, # "HE", "HE", "HERS", "S", "SHE", "HE"
) "HERSHE": 30.84, # "HE", "HERS", "S", "SHE", "HE"
"HISHE": 24.18, # "HIS", "S", "SHE", "HE"
"SHED": 18.34, # "S", "SHE", "HE"
"SHELF": 18.34, # "S", "SHE", "HE"
"HELL": 5, # "HE"
"HELLO": 13, # "HE", "HELLO"
"DHRHISQ": 10.84, # "HIS", "S"
"THEN": 5, # "HE"
}
_test(queries, 5)