Minor fixes

This commit is contained in:
pkufool 2023-06-15 10:21:12 +08:00
parent a1b12cf4e9
commit bf36d1984e

View File

@ -515,6 +515,7 @@ def decode_dataset(
params: AttributeDict,
model: nn.Module,
lexicon: Lexicon,
graph_compiler: CharCtcTrainingGraphCompiler,
decoding_graph: Optional[k2.Fsa] = None,
) -> Dict[str, List[Tuple[List[str], List[str]]]]:
"""Decode dataset.
@ -559,6 +560,7 @@ def decode_dataset(
params=params,
model=model,
lexicon=lexicon,
graph_compiler=graph_compiler,
decoding_graph=decoding_graph,
batch=batch,
)
@ -692,6 +694,11 @@ def main():
params.blank_id = lexicon.token_table["<blk>"]
params.vocab_size = max(lexicon.tokens) + 1
graph_compiler = CharCtcTrainingGraphCompiler(
lexicon=lexicon,
device=device,
)
logging.info(params)
logging.info("About to create model")
@ -780,7 +787,6 @@ def main():
if "fast_beam_search" in params.decoding_method:
if params.decoding_method == "fast_beam_search_nbest_LG":
lexicon = Lexicon(params.lang_dir)
# word_table = lexicon.word_table
lg_filename = params.lang_dir / "LG.pt"
logging.info(f"Loading {lg_filename}")
decoding_graph = k2.Fsa.from_dict(
@ -788,11 +794,9 @@ def main():
)
decoding_graph.scores *= params.ngram_lm_scale
else:
# word_table = None
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
else:
decoding_graph = None
# word_table = None
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
@ -832,7 +836,7 @@ def main():
params=params,
model=model,
lexicon=lexicon,
# word_table=word_table,
graph_compiler=graph_compiler,
decoding_graph=decoding_graph,
)