mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-08 09:32:20 +00:00
Add cumstomized score for hotwords (#1385)
* add custom score for each hotword * Add more comments * Fix deocde * fix style * minor fixes
This commit is contained in:
parent
666d69b20d
commit
11d816d174
@ -641,7 +641,7 @@ def main():
|
|||||||
contexts_text.append(line.strip())
|
contexts_text.append(line.strip())
|
||||||
contexts = graph_compiler.texts_to_ids(contexts_text)
|
contexts = graph_compiler.texts_to_ids(contexts_text)
|
||||||
context_graph = ContextGraph(params.context_score)
|
context_graph = ContextGraph(params.context_score)
|
||||||
context_graph.build(contexts)
|
context_graph.build([(c, 0.0) for c in contexts])
|
||||||
else:
|
else:
|
||||||
context_graph = None
|
context_graph = None
|
||||||
else:
|
else:
|
||||||
|
@ -686,7 +686,7 @@ def main():
|
|||||||
contexts_text.append(line.strip())
|
contexts_text.append(line.strip())
|
||||||
contexts = graph_compiler.texts_to_ids(contexts_text)
|
contexts = graph_compiler.texts_to_ids(contexts_text)
|
||||||
context_graph = ContextGraph(params.context_score)
|
context_graph = ContextGraph(params.context_score)
|
||||||
context_graph.build(contexts)
|
context_graph.build([(c, 0.0) for c in contexts])
|
||||||
else:
|
else:
|
||||||
context_graph = None
|
context_graph = None
|
||||||
else:
|
else:
|
||||||
|
@ -927,9 +927,9 @@ def main():
|
|||||||
if os.path.exists(params.context_file):
|
if os.path.exists(params.context_file):
|
||||||
contexts = []
|
contexts = []
|
||||||
for line in open(params.context_file).readlines():
|
for line in open(params.context_file).readlines():
|
||||||
contexts.append(line.strip())
|
contexts.append((sp.encode(line.strip()), 0.0))
|
||||||
context_graph = ContextGraph(params.context_score)
|
context_graph = ContextGraph(params.context_score)
|
||||||
context_graph.build(sp.encode(contexts))
|
context_graph.build(contexts)
|
||||||
else:
|
else:
|
||||||
context_graph = None
|
context_graph = None
|
||||||
else:
|
else:
|
||||||
|
@ -1001,9 +1001,9 @@ def main():
|
|||||||
if os.path.exists(params.context_file):
|
if os.path.exists(params.context_file):
|
||||||
contexts = []
|
contexts = []
|
||||||
for line in open(params.context_file).readlines():
|
for line in open(params.context_file).readlines():
|
||||||
contexts.append(line.strip())
|
contexts.append((sp.encode(line.strip()), 0.0))
|
||||||
context_graph = ContextGraph(params.context_score)
|
context_graph = ContextGraph(params.context_score)
|
||||||
context_graph.build(sp.encode(contexts))
|
context_graph.build(contexts)
|
||||||
else:
|
else:
|
||||||
context_graph = None
|
context_graph = None
|
||||||
else:
|
else:
|
||||||
|
@ -868,7 +868,7 @@ def main():
|
|||||||
contexts_text.append(line.strip())
|
contexts_text.append(line.strip())
|
||||||
contexts = graph_compiler.texts_to_ids(contexts_text)
|
contexts = graph_compiler.texts_to_ids(contexts_text)
|
||||||
context_graph = ContextGraph(params.context_score)
|
context_graph = ContextGraph(params.context_score)
|
||||||
context_graph.build(contexts)
|
context_graph.build([(c, 0.0) for c in contexts])
|
||||||
else:
|
else:
|
||||||
context_graph = None
|
context_graph = None
|
||||||
else:
|
else:
|
||||||
|
@ -84,6 +84,9 @@ class ContextGraph:
|
|||||||
context_score:
|
context_score:
|
||||||
The bonus score for each token(note: NOT for each word/phrase, it means longer
|
The bonus score for each token(note: NOT for each word/phrase, it means longer
|
||||||
word/phrase will have larger bonus score, they have to be matched though).
|
word/phrase will have larger bonus score, they have to be matched though).
|
||||||
|
Note: This is just the default score for each token, the users can manually
|
||||||
|
specify the context_score for each word/phrase (i.e. different phrase might
|
||||||
|
have different token score).
|
||||||
"""
|
"""
|
||||||
self.context_score = context_score
|
self.context_score = context_score
|
||||||
self.num_nodes = 0
|
self.num_nodes = 0
|
||||||
@ -133,7 +136,7 @@ class ContextGraph:
|
|||||||
node.output_score += 0 if output is None else output.output_score
|
node.output_score += 0 if output is None else output.output_score
|
||||||
queue.append(node)
|
queue.append(node)
|
||||||
|
|
||||||
def build(self, token_ids: List[List[int]]):
|
def build(self, token_ids: List[Tuple[List[int], float]]):
|
||||||
"""Build the ContextGraph from a list of token list.
|
"""Build the ContextGraph from a list of token list.
|
||||||
It first build a trie from the given token lists, then fill the fail arc
|
It first build a trie from the given token lists, then fill the fail arc
|
||||||
for each trie node.
|
for each trie node.
|
||||||
@ -142,26 +145,46 @@ class ContextGraph:
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
token_ids:
|
token_ids:
|
||||||
The given token lists to build the ContextGraph, it is a list of token list,
|
The given token lists to build the ContextGraph, it is a list of tuple of
|
||||||
each token list contains the token ids for a word/phrase. The token id
|
token list and its customized score, the token list contains the token ids
|
||||||
could be an id of a char (modeling with single Chinese char) or an id
|
for a word/phrase. The token id could be an id of a char
|
||||||
of a BPE (modeling with BPEs).
|
(modeling with single Chinese char) or an id of a BPE
|
||||||
|
(modeling with BPEs). The score is the total score for current token list,
|
||||||
|
0 means using the default value (i.e. self.context_score).
|
||||||
|
|
||||||
|
Note: The phrases would have shared states, the score of the shared states is
|
||||||
|
the maximum value among all the tokens sharing this state.
|
||||||
"""
|
"""
|
||||||
for tokens in token_ids:
|
for (tokens, score) in token_ids:
|
||||||
node = self.root
|
node = self.root
|
||||||
|
# If has customized score using the customized token score, otherwise
|
||||||
|
# using the default score
|
||||||
|
context_score = (
|
||||||
|
self.context_score if score == 0.0 else round(score / len(tokens), 2)
|
||||||
|
)
|
||||||
for i, token in enumerate(tokens):
|
for i, token in enumerate(tokens):
|
||||||
|
node_next = {}
|
||||||
if token not in node.next:
|
if token not in node.next:
|
||||||
self.num_nodes += 1
|
self.num_nodes += 1
|
||||||
|
node_id = self.num_nodes
|
||||||
|
token_score = context_score
|
||||||
is_end = i == len(tokens) - 1
|
is_end = i == len(tokens) - 1
|
||||||
node_score = node.node_score + self.context_score
|
else:
|
||||||
node.next[token] = ContextState(
|
# node exists, get the score of shared state.
|
||||||
id=self.num_nodes,
|
token_score = max(context_score, node.next[token].token_score)
|
||||||
token=token,
|
node_id = node.next[token].id
|
||||||
token_score=self.context_score,
|
node_next = node.next[token].next
|
||||||
node_score=node_score,
|
is_end = i == len(tokens) - 1 or node.next[token].is_end
|
||||||
output_score=node_score if is_end else 0,
|
node_score = node.node_score + token_score
|
||||||
is_end=is_end,
|
node.next[token] = ContextState(
|
||||||
)
|
id=node_id,
|
||||||
|
token=token,
|
||||||
|
token_score=token_score,
|
||||||
|
node_score=node_score,
|
||||||
|
output_score=node_score if is_end else 0,
|
||||||
|
is_end=is_end,
|
||||||
|
)
|
||||||
|
node.next[token].next = node_next
|
||||||
node = node.next[token]
|
node = node.next[token]
|
||||||
self._fill_fail_output()
|
self._fill_fail_output()
|
||||||
|
|
||||||
@ -343,7 +366,7 @@ class ContextGraph:
|
|||||||
return dot
|
return dot
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
def _test(queries, score):
|
||||||
contexts_str = [
|
contexts_str = [
|
||||||
"S",
|
"S",
|
||||||
"HE",
|
"HE",
|
||||||
@ -355,9 +378,11 @@ if __name__ == "__main__":
|
|||||||
"THIS",
|
"THIS",
|
||||||
"THEM",
|
"THEM",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
# test default score (1)
|
||||||
contexts = []
|
contexts = []
|
||||||
for s in contexts_str:
|
for s in contexts_str:
|
||||||
contexts.append([ord(x) for x in s])
|
contexts.append(([ord(x) for x in s], score))
|
||||||
|
|
||||||
context_graph = ContextGraph(context_score=1)
|
context_graph = ContextGraph(context_score=1)
|
||||||
context_graph.build(contexts)
|
context_graph.build(contexts)
|
||||||
@ -369,10 +394,28 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
context_graph.draw(
|
context_graph.draw(
|
||||||
title="Graph for: " + " / ".join(contexts_str),
|
title="Graph for: " + " / ".join(contexts_str),
|
||||||
filename="context_graph.pdf",
|
filename=f"context_graph_{score}.pdf",
|
||||||
symbol_table=symbol_table,
|
symbol_table=symbol_table,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
for query, expected_score in queries.items():
|
||||||
|
total_scores = 0
|
||||||
|
state = context_graph.root
|
||||||
|
for q in query:
|
||||||
|
score, state = context_graph.forward_one_step(state, ord(q))
|
||||||
|
total_scores += score
|
||||||
|
score, state = context_graph.finalize(state)
|
||||||
|
assert state.token == -1, state.token
|
||||||
|
total_scores += score
|
||||||
|
assert round(total_scores, 2) == expected_score, (
|
||||||
|
total_scores,
|
||||||
|
expected_score,
|
||||||
|
query,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
# test default score
|
||||||
queries = {
|
queries = {
|
||||||
"HEHERSHE": 14, # "HE", "HE", "HERS", "S", "SHE", "HE"
|
"HEHERSHE": 14, # "HE", "HE", "HERS", "S", "SHE", "HE"
|
||||||
"HERSHE": 12, # "HE", "HERS", "S", "SHE", "HE"
|
"HERSHE": 12, # "HE", "HERS", "S", "SHE", "HE"
|
||||||
@ -384,17 +427,27 @@ if __name__ == "__main__":
|
|||||||
"DHRHISQ": 4, # "HIS", "S"
|
"DHRHISQ": 4, # "HIS", "S"
|
||||||
"THEN": 2, # "HE"
|
"THEN": 2, # "HE"
|
||||||
}
|
}
|
||||||
for query, expected_score in queries.items():
|
_test(queries, 0)
|
||||||
total_scores = 0
|
|
||||||
state = context_graph.root
|
# test custom score (5)
|
||||||
for q in query:
|
# S : 5
|
||||||
score, state = context_graph.forward_one_step(state, ord(q))
|
# HE : 5 (2.5 + 2.5)
|
||||||
total_scores += score
|
# SHE : 8.34 (5 + 1.67 + 1.67)
|
||||||
score, state = context_graph.finalize(state)
|
# SHELL : 10.34 (5 + 1.67 + 1.67 + 1 + 1)
|
||||||
assert state.token == -1, state.token
|
# HIS : 5.84 (2.5 + 1.67 + 1.67)
|
||||||
total_scores += score
|
# HERS : 7.5 (2.5 + 2.5 + 1.25 + 1.25)
|
||||||
assert total_scores == expected_score, (
|
# HELLO : 8 (2.5 + 2.5 + 1 + 1 + 1)
|
||||||
total_scores,
|
# THIS : 5 (1.25 + 1.25 + 1.25 + 1.25)
|
||||||
expected_score,
|
queries = {
|
||||||
query,
|
"HEHERSHE": 35.84, # "HE", "HE", "HERS", "S", "SHE", "HE"
|
||||||
)
|
"HERSHE": 30.84, # "HE", "HERS", "S", "SHE", "HE"
|
||||||
|
"HISHE": 24.18, # "HIS", "S", "SHE", "HE"
|
||||||
|
"SHED": 18.34, # "S", "SHE", "HE"
|
||||||
|
"SHELF": 18.34, # "S", "SHE", "HE"
|
||||||
|
"HELL": 5, # "HE"
|
||||||
|
"HELLO": 13, # "HE", "HELLO"
|
||||||
|
"DHRHISQ": 10.84, # "HIS", "S"
|
||||||
|
"THEN": 5, # "HE"
|
||||||
|
}
|
||||||
|
|
||||||
|
_test(queries, 5)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user