Minor fixes

pkufool 2023-05-25 16:33:10 +08:00
parent 899f858659
commit 04c3f9ab53
2 changed files with 22 additions and 14 deletions


@@ -123,9 +123,9 @@ from beam_search import (
     greedy_search_batch,
     modified_beam_search,
 )
+from lhotse.cut import Cut
 from train import add_model_arguments, get_params, get_transducer_model
-from lhotse.cut import Cut
 
 from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
 from icefall.checkpoint import (
     average_checkpoints,
@@ -324,7 +324,7 @@ def decode_one_batch(
     params: AttributeDict,
     model: nn.Module,
     lexicon: Lexicon,
-    # graph_compiler: CharCtcTrainingGraphCompiler,
+    graph_compiler: CharCtcTrainingGraphCompiler,
     batch: dict,
     decoding_graph: Optional[k2.Fsa] = None,
 ) -> Dict[str, List[List[str]]]:
@@ -431,7 +431,10 @@ def decode_one_batch(
             hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
     elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
         hyp_tokens = greedy_search_batch(
-            model=model, encoder_out=encoder_out, encoder_out_lens=encoder_out_lens, blank_penalty=params.blank_penalty,
+            model=model,
+            encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
+            blank_penalty=params.blank_penalty,
         )
         for i in range(encoder_out.size(0)):
             hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
@@ -461,7 +464,9 @@ def decode_one_batch(
                 )
             elif params.decoding_method == "beam_search":
                 hyp = beam_search(
-                    model=model, encoder_out=encoder_out_i, beam=params.beam_size,
+                    model=model,
+                    encoder_out=encoder_out_i,
+                    beam=params.beam_size,
                     blank_penalty=params.blank_penalty,
                 )
             else:
@@ -493,6 +498,7 @@ def decode_dataset(
     params: AttributeDict,
     model: nn.Module,
     lexicon: Lexicon,
+    graph_compiler: CharCtcTrainingGraphCompiler,
     decoding_graph: Optional[k2.Fsa] = None,
 ) -> Dict[str, List[Tuple[List[str], List[str]]]]:
     """Decode dataset.
@@ -538,6 +544,7 @@ def decode_dataset(
             params=params,
             model=model,
             lexicon=lexicon,
             decoding_graph=decoding_graph,
+            graph_compiler=graph_compiler,
             batch=batch,
         )
@@ -660,6 +667,11 @@ def main():
     params.blank_id = lexicon.token_table["<blk>"]
     params.vocab_size = max(lexicon.tokens) + 1
 
+    graph_compiler = CharCtcTrainingGraphCompiler(
+        lexicon=lexicon,
+        device=device,
+    )
+
     logging.info(params)
 
     logging.info("About to create model")
@@ -747,8 +759,6 @@ def main():
 
     if "fast_beam_search" in params.decoding_method:
         if params.decoding_method == "fast_beam_search_nbest_LG":
-            lexicon = Lexicon(params.lang_dir)
-            # word_table = lexicon.word_table
             lg_filename = params.lang_dir / "LG.pt"
             logging.info(f"Loading {lg_filename}")
             decoding_graph = k2.Fsa.from_dict(
@@ -756,11 +766,9 @@ def main():
             )
             decoding_graph.scores *= params.ngram_lm_scale
         else:
-            # word_table = None
            decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
     else:
         decoding_graph = None
-        # word_table = None
 
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")
@@ -791,8 +799,6 @@ def main():
     test_sets = ["DEV", "TEST_NET", "TEST_MEETING"]
     test_dl = [dev_dl, test_net_dl, test_meeting_dl]
 
-    # test_sets = ["TEST_MEETING"]
-    # test_dl = [test_meeting_dl]
 
     for test_set, test_dl in zip(test_sets, test_dl):
         results_dict = decode_dataset(
@@ -800,12 +806,14 @@ def main():
             params=params,
             model=model,
             lexicon=lexicon,
-            # word_table=word_table,
+            graph_compiler=graph_compiler,
             decoding_graph=decoding_graph,
         )
 
         save_results(
-            params=params, test_set_name=test_set, results_dict=results_dict,
+            params=params,
+            test_set_name=test_set,
+            results_dict=results_dict,
        )
 
     logging.info("Done!")