mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-02 21:54:18 +00:00
Minor fixes
This commit is contained in:
parent
7d91e8b6d5
commit
80903858a2
@ -211,7 +211,7 @@ def decode_one_batch(
|
||||
model: nn.Module,
|
||||
lexicon: Lexicon,
|
||||
batch: dict,
|
||||
kws_graph: ContextGraph,
|
||||
keywords_graph: ContextGraph,
|
||||
) -> Dict[str, List[List[str]]]:
|
||||
"""Decode one batch and return the result in a dict. The dict has the
|
||||
following format:
|
||||
@ -272,7 +272,7 @@ def decode_one_batch(
|
||||
model=model,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
keywords_graph=kws_graph,
|
||||
keywords_graph=keywords_graph,
|
||||
beam=params.beam_size,
|
||||
num_tailing_blanks=8,
|
||||
)
|
||||
@ -297,7 +297,7 @@ def decode_dataset(
|
||||
params: AttributeDict,
|
||||
model: nn.Module,
|
||||
lexicon: Lexicon,
|
||||
kws_graph: ContextGraph,
|
||||
keywords_graph: ContextGraph,
|
||||
keywords: Set[str],
|
||||
test_only_keywords: bool,
|
||||
) -> Dict[str, List[Tuple[List[str], List[str]]]]:
|
||||
@ -343,7 +343,7 @@ def decode_dataset(
|
||||
params=params,
|
||||
model=model,
|
||||
lexicon=lexicon,
|
||||
kws_graph=kws_graph,
|
||||
keywords_graph=keywords_graph,
|
||||
batch=batch,
|
||||
)
|
||||
|
||||
@ -561,10 +561,10 @@ def main():
|
||||
keywords_thresholds.append(threshold)
|
||||
params.keywords_config = "".join(keywords_config)
|
||||
|
||||
kws_graph = ContextGraph(
|
||||
keywords_graph = ContextGraph(
|
||||
context_score=params.keywords_score, ac_threshold=params.keywords_threshold
|
||||
)
|
||||
kws_graph.build(
|
||||
keywords_graph.build(
|
||||
token_ids=token_ids,
|
||||
phrases=phrases,
|
||||
scores=keywords_scores,
|
||||
@ -697,8 +697,8 @@ def main():
|
||||
test_sets = []
|
||||
test_dls = []
|
||||
if params.test_set == "large":
|
||||
test_sets.append("cn_commands_large")
|
||||
test_dls.append(cn_commands_large_dl)
|
||||
test_sets += ["cn_commands_large", "test_net"]
|
||||
test_dls += [cn_commands_large_dl, test_net_dl]
|
||||
else:
|
||||
assert params.test_set == "small", params.test_set
|
||||
test_sets += [
|
||||
@ -722,7 +722,7 @@ def main():
|
||||
params=params,
|
||||
model=model,
|
||||
lexicon=lexicon,
|
||||
kws_graph=kws_graph,
|
||||
keywords_graph=keywords_graph,
|
||||
keywords=keywords,
|
||||
test_only_keywords="test_net" not in test_set,
|
||||
)
|
||||
|
Loading…
x
Reference in New Issue
Block a user