From 04c3f9ab53b66167194a5f35fae3497a1433042b Mon Sep 17 00:00:00 2001 From: pkufool Date: Thu, 25 May 2023 16:33:10 +0800 Subject: [PATCH] Minor fixes --- .../beam_search.py | 4 +-- .../pruned_transducer_stateless7/decode.py | 32 ++++++++++++------- 2 files changed, 22 insertions(+), 14 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py index 70df7bc08..40efc98ad 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py @@ -608,7 +608,7 @@ def greedy_search( # logits is (1, 1, 1, vocab_size) if blank_penalty != 0: - logits[:,:,:,0] -= blank_penalty + logits[:, :, :, 0] -= blank_penalty y = logits.argmax().item() if y not in (blank_id, unk_id): @@ -1748,7 +1748,7 @@ def beam_search( ) if blank_penalty != 0: - logits[:,:,:,0] -= blank_penalty + logits[:, :, :, 0] -= blank_penalty # TODO(fangjun): Scale the blank posterior log_prob = (logits / temperature).log_softmax(dim=-1) diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless7/decode.py b/egs/wenetspeech/ASR/pruned_transducer_stateless7/decode.py index e3931509b..1363b0ab7 100755 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless7/decode.py +++ b/egs/wenetspeech/ASR/pruned_transducer_stateless7/decode.py @@ -123,9 +123,9 @@ from beam_search import ( greedy_search_batch, modified_beam_search, ) +from lhotse.cut import Cut from train import add_model_arguments, get_params, get_transducer_model -from lhotse.cut import Cut from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler from icefall.checkpoint import ( average_checkpoints, @@ -324,7 +324,7 @@ def decode_one_batch( params: AttributeDict, model: nn.Module, lexicon: Lexicon, - # graph_compiler: CharCtcTrainingGraphCompiler, + graph_compiler: CharCtcTrainingGraphCompiler, batch: dict, decoding_graph: Optional[k2.Fsa] = None, ) -> Dict[str, List[List[str]]]: @@ -431,7 +431,10 @@ def decode_one_batch( hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( - model=model, encoder_out=encoder_out, encoder_out_lens=encoder_out_lens, blank_penalty=params.blank_penalty, + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + blank_penalty=params.blank_penalty, ) for i in range(encoder_out.size(0)): hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) @@ -461,7 +464,9 @@ def decode_one_batch( ) elif params.decoding_method == "beam_search": hyp = beam_search( - model=model, encoder_out=encoder_out_i, beam=params.beam_size, + model=model, + encoder_out=encoder_out_i, + beam=params.beam_size, blank_penalty=params.blank_penalty, ) else: @@ -493,6 +498,7 @@ def decode_dataset( params: AttributeDict, model: nn.Module, lexicon: Lexicon, + graph_compiler: CharCtcTrainingGraphCompiler, decoding_graph: Optional[k2.Fsa] = None, ) -> Dict[str, List[Tuple[List[str], List[str]]]]: """Decode dataset. @@ -538,6 +544,7 @@ def decode_dataset( model=model, lexicon=lexicon, decoding_graph=decoding_graph, + graph_compiler=graph_compiler, batch=batch, ) @@ -660,6 +667,11 @@ def main(): params.blank_id = lexicon.token_table[""] params.vocab_size = max(lexicon.tokens) + 1 + graph_compiler = CharCtcTrainingGraphCompiler( + lexicon=lexicon, + device=device, + ) + logging.info(params) logging.info("About to create model") @@ -747,8 +759,6 @@ def main(): if "fast_beam_search" in params.decoding_method: if params.decoding_method == "fast_beam_search_nbest_LG": - lexicon = Lexicon(params.lang_dir) - # word_table = lexicon.word_table lg_filename = params.lang_dir / "LG.pt" logging.info(f"Loading {lg_filename}") decoding_graph = k2.Fsa.from_dict( @@ -756,11 +766,9 @@ def main(): ) decoding_graph.scores *= params.ngram_lm_scale else: - # word_table = None decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) else: decoding_graph = None - # word_table = None num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}") @@ -791,8 +799,6 @@ def main(): test_sets = ["DEV", "TEST_NET", "TEST_MEETING"] test_dl = [dev_dl, test_net_dl, test_meeting_dl] - # test_sets = ["TEST_MEETING"] - # test_dl = [test_meeting_dl] for test_set, test_dl in zip(test_sets, test_dl): results_dict = decode_dataset( @@ -800,12 +806,14 @@ def main(): params=params, model=model, lexicon=lexicon, - # word_table=word_table, + graph_compiler=graph_compiler, decoding_graph=decoding_graph, ) save_results( - params=params, test_set_name=test_set, results_dict=results_dict, + params=params, + test_set_name=test_set, + results_dict=results_dict, ) logging.info("Done!")