mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-08 01:22:22 +00:00
Add cumstomized score for hotwords (#1385)
* add custom score for each hotword * Add more comments * Fix deocde * fix style * minor fixes
This commit is contained in:
parent
666d69b20d
commit
11d816d174
@ -641,7 +641,7 @@ def main():
|
||||
contexts_text.append(line.strip())
|
||||
contexts = graph_compiler.texts_to_ids(contexts_text)
|
||||
context_graph = ContextGraph(params.context_score)
|
||||
context_graph.build(contexts)
|
||||
context_graph.build([(c, 0.0) for c in contexts])
|
||||
else:
|
||||
context_graph = None
|
||||
else:
|
||||
|
@ -686,7 +686,7 @@ def main():
|
||||
contexts_text.append(line.strip())
|
||||
contexts = graph_compiler.texts_to_ids(contexts_text)
|
||||
context_graph = ContextGraph(params.context_score)
|
||||
context_graph.build(contexts)
|
||||
context_graph.build([(c, 0.0) for c in contexts])
|
||||
else:
|
||||
context_graph = None
|
||||
else:
|
||||
|
@ -927,9 +927,9 @@ def main():
|
||||
if os.path.exists(params.context_file):
|
||||
contexts = []
|
||||
for line in open(params.context_file).readlines():
|
||||
contexts.append(line.strip())
|
||||
contexts.append((sp.encode(line.strip()), 0.0))
|
||||
context_graph = ContextGraph(params.context_score)
|
||||
context_graph.build(sp.encode(contexts))
|
||||
context_graph.build(contexts)
|
||||
else:
|
||||
context_graph = None
|
||||
else:
|
||||
|
@ -1001,9 +1001,9 @@ def main():
|
||||
if os.path.exists(params.context_file):
|
||||
contexts = []
|
||||
for line in open(params.context_file).readlines():
|
||||
contexts.append(line.strip())
|
||||
contexts.append((sp.encode(line.strip()), 0.0))
|
||||
context_graph = ContextGraph(params.context_score)
|
||||
context_graph.build(sp.encode(contexts))
|
||||
context_graph.build(contexts)
|
||||
else:
|
||||
context_graph = None
|
||||
else:
|
||||
|
@ -868,7 +868,7 @@ def main():
|
||||
contexts_text.append(line.strip())
|
||||
contexts = graph_compiler.texts_to_ids(contexts_text)
|
||||
context_graph = ContextGraph(params.context_score)
|
||||
context_graph.build(contexts)
|
||||
context_graph.build([(c, 0.0) for c in contexts])
|
||||
else:
|
||||
context_graph = None
|
||||
else:
|
||||
|
@ -84,6 +84,9 @@ class ContextGraph:
|
||||
context_score:
|
||||
The bonus score for each token(note: NOT for each word/phrase, it means longer
|
||||
word/phrase will have larger bonus score, they have to be matched though).
|
||||
Note: This is just the default score for each token, the users can manually
|
||||
specify the context_score for each word/phrase (i.e. different phrase might
|
||||
have different token score).
|
||||
"""
|
||||
self.context_score = context_score
|
||||
self.num_nodes = 0
|
||||
@ -133,7 +136,7 @@ class ContextGraph:
|
||||
node.output_score += 0 if output is None else output.output_score
|
||||
queue.append(node)
|
||||
|
||||
def build(self, token_ids: List[List[int]]):
|
||||
def build(self, token_ids: List[Tuple[List[int], float]]):
|
||||
"""Build the ContextGraph from a list of token list.
|
||||
It first build a trie from the given token lists, then fill the fail arc
|
||||
for each trie node.
|
||||
@ -142,26 +145,46 @@ class ContextGraph:
|
||||
|
||||
Args:
|
||||
token_ids:
|
||||
The given token lists to build the ContextGraph, it is a list of token list,
|
||||
each token list contains the token ids for a word/phrase. The token id
|
||||
could be an id of a char (modeling with single Chinese char) or an id
|
||||
of a BPE (modeling with BPEs).
|
||||
The given token lists to build the ContextGraph, it is a list of tuple of
|
||||
token list and its customized score, the token list contains the token ids
|
||||
for a word/phrase. The token id could be an id of a char
|
||||
(modeling with single Chinese char) or an id of a BPE
|
||||
(modeling with BPEs). The score is the total score for current token list,
|
||||
0 means using the default value (i.e. self.context_score).
|
||||
|
||||
Note: The phrases would have shared states, the score of the shared states is
|
||||
the maximum value among all the tokens sharing this state.
|
||||
"""
|
||||
for tokens in token_ids:
|
||||
for (tokens, score) in token_ids:
|
||||
node = self.root
|
||||
# If has customized score using the customized token score, otherwise
|
||||
# using the default score
|
||||
context_score = (
|
||||
self.context_score if score == 0.0 else round(score / len(tokens), 2)
|
||||
)
|
||||
for i, token in enumerate(tokens):
|
||||
node_next = {}
|
||||
if token not in node.next:
|
||||
self.num_nodes += 1
|
||||
node_id = self.num_nodes
|
||||
token_score = context_score
|
||||
is_end = i == len(tokens) - 1
|
||||
node_score = node.node_score + self.context_score
|
||||
node.next[token] = ContextState(
|
||||
id=self.num_nodes,
|
||||
token=token,
|
||||
token_score=self.context_score,
|
||||
node_score=node_score,
|
||||
output_score=node_score if is_end else 0,
|
||||
is_end=is_end,
|
||||
)
|
||||
else:
|
||||
# node exists, get the score of shared state.
|
||||
token_score = max(context_score, node.next[token].token_score)
|
||||
node_id = node.next[token].id
|
||||
node_next = node.next[token].next
|
||||
is_end = i == len(tokens) - 1 or node.next[token].is_end
|
||||
node_score = node.node_score + token_score
|
||||
node.next[token] = ContextState(
|
||||
id=node_id,
|
||||
token=token,
|
||||
token_score=token_score,
|
||||
node_score=node_score,
|
||||
output_score=node_score if is_end else 0,
|
||||
is_end=is_end,
|
||||
)
|
||||
node.next[token].next = node_next
|
||||
node = node.next[token]
|
||||
self._fill_fail_output()
|
||||
|
||||
@ -343,7 +366,7 @@ class ContextGraph:
|
||||
return dot
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
def _test(queries, score):
|
||||
contexts_str = [
|
||||
"S",
|
||||
"HE",
|
||||
@ -355,9 +378,11 @@ if __name__ == "__main__":
|
||||
"THIS",
|
||||
"THEM",
|
||||
]
|
||||
|
||||
# test default score (1)
|
||||
contexts = []
|
||||
for s in contexts_str:
|
||||
contexts.append([ord(x) for x in s])
|
||||
contexts.append(([ord(x) for x in s], score))
|
||||
|
||||
context_graph = ContextGraph(context_score=1)
|
||||
context_graph.build(contexts)
|
||||
@ -369,10 +394,28 @@ if __name__ == "__main__":
|
||||
|
||||
context_graph.draw(
|
||||
title="Graph for: " + " / ".join(contexts_str),
|
||||
filename="context_graph.pdf",
|
||||
filename=f"context_graph_{score}.pdf",
|
||||
symbol_table=symbol_table,
|
||||
)
|
||||
|
||||
for query, expected_score in queries.items():
|
||||
total_scores = 0
|
||||
state = context_graph.root
|
||||
for q in query:
|
||||
score, state = context_graph.forward_one_step(state, ord(q))
|
||||
total_scores += score
|
||||
score, state = context_graph.finalize(state)
|
||||
assert state.token == -1, state.token
|
||||
total_scores += score
|
||||
assert round(total_scores, 2) == expected_score, (
|
||||
total_scores,
|
||||
expected_score,
|
||||
query,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# test default score
|
||||
queries = {
|
||||
"HEHERSHE": 14, # "HE", "HE", "HERS", "S", "SHE", "HE"
|
||||
"HERSHE": 12, # "HE", "HERS", "S", "SHE", "HE"
|
||||
@ -384,17 +427,27 @@ if __name__ == "__main__":
|
||||
"DHRHISQ": 4, # "HIS", "S"
|
||||
"THEN": 2, # "HE"
|
||||
}
|
||||
for query, expected_score in queries.items():
|
||||
total_scores = 0
|
||||
state = context_graph.root
|
||||
for q in query:
|
||||
score, state = context_graph.forward_one_step(state, ord(q))
|
||||
total_scores += score
|
||||
score, state = context_graph.finalize(state)
|
||||
assert state.token == -1, state.token
|
||||
total_scores += score
|
||||
assert total_scores == expected_score, (
|
||||
total_scores,
|
||||
expected_score,
|
||||
query,
|
||||
)
|
||||
_test(queries, 0)
|
||||
|
||||
# test custom score (5)
|
||||
# S : 5
|
||||
# HE : 5 (2.5 + 2.5)
|
||||
# SHE : 8.34 (5 + 1.67 + 1.67)
|
||||
# SHELL : 10.34 (5 + 1.67 + 1.67 + 1 + 1)
|
||||
# HIS : 5.84 (2.5 + 1.67 + 1.67)
|
||||
# HERS : 7.5 (2.5 + 2.5 + 1.25 + 1.25)
|
||||
# HELLO : 8 (2.5 + 2.5 + 1 + 1 + 1)
|
||||
# THIS : 5 (1.25 + 1.25 + 1.25 + 1.25)
|
||||
queries = {
|
||||
"HEHERSHE": 35.84, # "HE", "HE", "HERS", "S", "SHE", "HE"
|
||||
"HERSHE": 30.84, # "HE", "HERS", "S", "SHE", "HE"
|
||||
"HISHE": 24.18, # "HIS", "S", "SHE", "HE"
|
||||
"SHED": 18.34, # "S", "SHE", "HE"
|
||||
"SHELF": 18.34, # "S", "SHE", "HE"
|
||||
"HELL": 5, # "HE"
|
||||
"HELLO": 13, # "HE", "HELLO"
|
||||
"DHRHISQ": 10.84, # "HIS", "S"
|
||||
"THEN": 5, # "HE"
|
||||
}
|
||||
|
||||
_test(queries, 5)
|
||||
|
Loading…
x
Reference in New Issue
Block a user