mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-03 22:24:19 +00:00
Minor fixes
This commit is contained in:
parent
7d91e8b6d5
commit
80903858a2
@ -211,7 +211,7 @@ def decode_one_batch(
|
|||||||
model: nn.Module,
|
model: nn.Module,
|
||||||
lexicon: Lexicon,
|
lexicon: Lexicon,
|
||||||
batch: dict,
|
batch: dict,
|
||||||
kws_graph: ContextGraph,
|
keywords_graph: ContextGraph,
|
||||||
) -> Dict[str, List[List[str]]]:
|
) -> Dict[str, List[List[str]]]:
|
||||||
"""Decode one batch and return the result in a dict. The dict has the
|
"""Decode one batch and return the result in a dict. The dict has the
|
||||||
following format:
|
following format:
|
||||||
@ -272,7 +272,7 @@ def decode_one_batch(
|
|||||||
model=model,
|
model=model,
|
||||||
encoder_out=encoder_out,
|
encoder_out=encoder_out,
|
||||||
encoder_out_lens=encoder_out_lens,
|
encoder_out_lens=encoder_out_lens,
|
||||||
keywords_graph=kws_graph,
|
keywords_graph=keywords_graph,
|
||||||
beam=params.beam_size,
|
beam=params.beam_size,
|
||||||
num_tailing_blanks=8,
|
num_tailing_blanks=8,
|
||||||
)
|
)
|
||||||
@ -297,7 +297,7 @@ def decode_dataset(
|
|||||||
params: AttributeDict,
|
params: AttributeDict,
|
||||||
model: nn.Module,
|
model: nn.Module,
|
||||||
lexicon: Lexicon,
|
lexicon: Lexicon,
|
||||||
kws_graph: ContextGraph,
|
keywords_graph: ContextGraph,
|
||||||
keywords: Set[str],
|
keywords: Set[str],
|
||||||
test_only_keywords: bool,
|
test_only_keywords: bool,
|
||||||
) -> Dict[str, List[Tuple[List[str], List[str]]]]:
|
) -> Dict[str, List[Tuple[List[str], List[str]]]]:
|
||||||
@ -343,7 +343,7 @@ def decode_dataset(
|
|||||||
params=params,
|
params=params,
|
||||||
model=model,
|
model=model,
|
||||||
lexicon=lexicon,
|
lexicon=lexicon,
|
||||||
kws_graph=kws_graph,
|
keywords_graph=keywords_graph,
|
||||||
batch=batch,
|
batch=batch,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -561,10 +561,10 @@ def main():
|
|||||||
keywords_thresholds.append(threshold)
|
keywords_thresholds.append(threshold)
|
||||||
params.keywords_config = "".join(keywords_config)
|
params.keywords_config = "".join(keywords_config)
|
||||||
|
|
||||||
kws_graph = ContextGraph(
|
keywords_graph = ContextGraph(
|
||||||
context_score=params.keywords_score, ac_threshold=params.keywords_threshold
|
context_score=params.keywords_score, ac_threshold=params.keywords_threshold
|
||||||
)
|
)
|
||||||
kws_graph.build(
|
keywords_graph.build(
|
||||||
token_ids=token_ids,
|
token_ids=token_ids,
|
||||||
phrases=phrases,
|
phrases=phrases,
|
||||||
scores=keywords_scores,
|
scores=keywords_scores,
|
||||||
@ -697,8 +697,8 @@ def main():
|
|||||||
test_sets = []
|
test_sets = []
|
||||||
test_dls = []
|
test_dls = []
|
||||||
if params.test_set == "large":
|
if params.test_set == "large":
|
||||||
test_sets.append("cn_commands_large")
|
test_sets += ["cn_commands_large", "test_net"]
|
||||||
test_dls.append(cn_commands_large_dl)
|
test_dls += [cn_commands_large_dl, test_net_dl]
|
||||||
else:
|
else:
|
||||||
assert params.test_set == "small", params.test_set
|
assert params.test_set == "small", params.test_set
|
||||||
test_sets += [
|
test_sets += [
|
||||||
@ -722,7 +722,7 @@ def main():
|
|||||||
params=params,
|
params=params,
|
||||||
model=model,
|
model=model,
|
||||||
lexicon=lexicon,
|
lexicon=lexicon,
|
||||||
kws_graph=kws_graph,
|
keywords_graph=keywords_graph,
|
||||||
keywords=keywords,
|
keywords=keywords,
|
||||||
test_only_keywords="test_net" not in test_set,
|
test_only_keywords="test_net" not in test_set,
|
||||||
)
|
)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user