Minor fixes

pkufool 2023-06-15 10:21:12 +08:00
parent a1b12cf4e9
commit bf36d1984e


@@ -515,6 +515,7 @@ def decode_dataset(
     params: AttributeDict,
     model: nn.Module,
     lexicon: Lexicon,
+    graph_compiler: CharCtcTrainingGraphCompiler,
     decoding_graph: Optional[k2.Fsa] = None,
 ) -> Dict[str, List[Tuple[List[str], List[str]]]]:
     """Decode dataset.
@@ -559,6 +560,7 @@ def decode_dataset(
             params=params,
             model=model,
             lexicon=lexicon,
+            graph_compiler=graph_compiler,
             decoding_graph=decoding_graph,
             batch=batch,
         )
@@ -692,6 +694,11 @@ def main():
     params.blank_id = lexicon.token_table["<blk>"]
     params.vocab_size = max(lexicon.tokens) + 1
 
+    graph_compiler = CharCtcTrainingGraphCompiler(
+        lexicon=lexicon,
+        device=device,
+    )
+
     logging.info(params)
 
     logging.info("About to create model")
@@ -780,7 +787,6 @@ def main():
     if "fast_beam_search" in params.decoding_method:
         if params.decoding_method == "fast_beam_search_nbest_LG":
             lexicon = Lexicon(params.lang_dir)
-            # word_table = lexicon.word_table
             lg_filename = params.lang_dir / "LG.pt"
             logging.info(f"Loading {lg_filename}")
             decoding_graph = k2.Fsa.from_dict(
@@ -788,11 +794,9 @@ def main():
             )
             decoding_graph.scores *= params.ngram_lm_scale
         else:
-            # word_table = None
            decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
     else:
         decoding_graph = None
-        # word_table = None
 
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")
@@ -832,7 +836,7 @@ def main():
             params=params,
             model=model,
             lexicon=lexicon,
-            # word_table=word_table,
+            graph_compiler=graph_compiler,
             decoding_graph=decoding_graph,
         )
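Taken together, the change drops the commented-out word_table plumbing and instead builds a CharCtcTrainingGraphCompiler in main() that is threaded through to decode_dataset(). A minimal sketch of the resulting call path is shown below; names such as test_dl, results_dict, and the dl= keyword are assumed placeholders for the surrounding decode.py code and do not appear in this diff.

# Sketch only: `test_dl`, `results_dict`, and the `dl=` keyword are assumptions,
# not part of this diff.
lexicon = Lexicon(params.lang_dir)

graph_compiler = CharCtcTrainingGraphCompiler(
    lexicon=lexicon,
    device=device,
)

results_dict = decode_dataset(
    dl=test_dl,
    params=params,
    model=model,
    lexicon=lexicon,
    graph_compiler=graph_compiler,  # new argument introduced by this commit
    decoding_graph=decoding_graph,  # LG graph, trivial graph, or None, as before
)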