Add more comments

pkufool 2023-11-16 11:30:50 +08:00
parent 8c5f5795d4
commit cba104a29d

@@ -84,6 +84,9 @@ class ContextGraph:
context_score:
The bonus score for each token (note: NOT for each word/phrase, so a longer
word/phrase gets a larger total bonus score, although it still has to be matched).
Note: This is only the default score for each token; users can manually
specify the context_score for each word/phrase (i.e. different phrases might
have different token scores).
"""
self.context_score = context_score
self.num_nodes = 0
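
Review note: because the bonus is applied per token, the total bonus of a fully matched phrase grows with its length. A minimal sketch of that arithmetic, assuming a hypothetical helper `phrase_bonus` that is not part of the commit:

```python
def phrase_bonus(num_tokens: int, context_score: float) -> float:
    """Total bonus collected when every token of a phrase is matched.

    Hypothetical helper for illustration only; the commit itself stores the
    score per token on each graph state.
    """
    return num_tokens * context_score


# With the same per-token score, the longer phrase earns the larger bonus.
short_bonus = phrase_bonus(2, 1.5)
long_bonus = phrase_bonus(5, 1.5)
```
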
@@ -142,13 +145,20 @@ class ContextGraph:
Args:
token_ids:
The given token lists to build the ContextGraph, it is a list of token list,
each token list contains the token ids for a word/phrase. The token id
could be an id of a char (modeling with single Chinese char) or an id
of a BPE (modeling with BPEs).
The token lists used to build the ContextGraph. It is a list of tuples, each
holding a token list and its customized score; the token list contains the
token ids for a word/phrase. A token id could be the id of a char
(modeling with single Chinese chars) or the id of a BPE
(modeling with BPEs). The score is the total score for the current token list;
0 means using the default value (i.e. self.context_score).
Note: Phrases may share states; the score of a shared state is
the maximum value among all the tokens sharing this state.
"""
for (tokens, score) in token_ids:
node = self.root
# Use the customized per-token score if one is given; otherwise
# use the default score.
context_score = self.context_score if score == 0.0 else round(score / len(tokens), 2)
for i, token in enumerate(tokens):
node_next = {}
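
Review note: the scoring rule described in this hunk can be sketched in isolation as below. This is a simplified illustration, not the code in this commit: `Node` and `build` are hypothetical names, and the real class tracks more per-state fields than shown here.

```python
from typing import Dict, List, Tuple


class Node:
    """Hypothetical, stripped-down trie state (not the real graph state)."""

    def __init__(self, token_score: float) -> None:
        self.token_score = token_score
        self.next: Dict[int, "Node"] = {}


def build(token_ids: List[Tuple[List[int], float]], default_score: float) -> Node:
    """Build a trie from (token list, total score) pairs."""
    root = Node(0.0)
    for tokens, score in token_ids:
        # A total score of 0 means "use the default per-token score";
        # otherwise spread the total evenly over the tokens of the phrase.
        per_token = default_score if score == 0.0 else round(score / len(tokens), 2)
        node = root
        for token in tokens:
            if token not in node.next:
                node.next[token] = Node(per_token)
            else:
                # Shared state: keep the largest score among the phrases using it.
                child = node.next[token]
                child.token_score = max(per_token, child.token_score)
            node = node.next[token]
    return root


# Two phrases sharing the prefix [1, 2]; the second carries a total score of 9.0.
root = build([([1, 2, 3], 0.0), ([1, 2, 4], 9.0)], default_score=1.0)
```

In this sketch the shared states for tokens 1 and 2 end up with the larger per-token score of the two phrases, while the unshared final states keep their own phrase's score.
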
@@ -158,6 +168,7 @@ class ContextGraph:
token_score = context_score
is_end = i == len(tokens) - 1
else:
# The node already exists; take the maximum score for the shared state.
token_score = max(context_score, node.next[token].token_score)
node_id = node.next[token].id
node_next = node.next[token].next