Minor fixes

This commit is contained in:
pkufool 2024-02-19 14:34:25 +08:00
parent 7d91e8b6d5
commit 80903858a2

View File

@ -211,7 +211,7 @@ def decode_one_batch(
model: nn.Module,
lexicon: Lexicon,
batch: dict,
kws_graph: ContextGraph,
keywords_graph: ContextGraph,
) -> Dict[str, List[List[str]]]:
"""Decode one batch and return the result in a dict. The dict has the
following format:
@ -272,7 +272,7 @@ def decode_one_batch(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
keywords_graph=kws_graph,
keywords_graph=keywords_graph,
beam=params.beam_size,
num_tailing_blanks=8,
)
@ -297,7 +297,7 @@ def decode_dataset(
params: AttributeDict,
model: nn.Module,
lexicon: Lexicon,
kws_graph: ContextGraph,
keywords_graph: ContextGraph,
keywords: Set[str],
test_only_keywords: bool,
) -> Dict[str, List[Tuple[List[str], List[str]]]]:
@ -343,7 +343,7 @@ def decode_dataset(
params=params,
model=model,
lexicon=lexicon,
kws_graph=kws_graph,
keywords_graph=keywords_graph,
batch=batch,
)
@ -561,10 +561,10 @@ def main():
keywords_thresholds.append(threshold)
params.keywords_config = "".join(keywords_config)
kws_graph = ContextGraph(
keywords_graph = ContextGraph(
context_score=params.keywords_score, ac_threshold=params.keywords_threshold
)
kws_graph.build(
keywords_graph.build(
token_ids=token_ids,
phrases=phrases,
scores=keywords_scores,
@ -697,8 +697,8 @@ def main():
test_sets = []
test_dls = []
if params.test_set == "large":
test_sets.append("cn_commands_large")
test_dls.append(cn_commands_large_dl)
test_sets += ["cn_commands_large", "test_net"]
test_dls += [cn_commands_large_dl, test_net_dl]
else:
assert params.test_set == "small", params.test_set
test_sets += [
@ -722,7 +722,7 @@ def main():
params=params,
model=model,
lexicon=lexicon,
kws_graph=kws_graph,
keywords_graph=keywords_graph,
keywords=keywords,
test_only_keywords="test_net" not in test_set,
)