From e8b42bab73026f5a396815884bc51fb043c601c6 Mon Sep 17 00:00:00 2001
From: wangtiance
Date: Thu, 26 Oct 2023 19:21:23 +0800
Subject: [PATCH] black format

---
 .../ASR/tiny_transducer_ctc/decode.py | 60 +++++++------------
 1 file changed, 22 insertions(+), 38 deletions(-)

diff --git a/egs/librispeech/ASR/tiny_transducer_ctc/decode.py b/egs/librispeech/ASR/tiny_transducer_ctc/decode.py
index 74aae3ad3..6c2bf9ea1 100644
--- a/egs/librispeech/ASR/tiny_transducer_ctc/decode.py
+++ b/egs/librispeech/ASR/tiny_transducer_ctc/decode.py
@@ -39,6 +39,7 @@ from icefall.utils import (
 LOG_EPS = math.log(1e-10)
 
+
 def get_parser():
     parser = argparse.ArgumentParser(
         formatter_class=argparse.ArgumentDefaultsHelpFormatter
     )
@@ -168,8 +169,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -235,7 +235,7 @@ def decode_one_batch(
         The word symbol table.
       decoding_graph:
         The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
-        only when --decoding_method is fast_beam_search, fast_beam_search_LG, 
+        only when --decoding_method is fast_beam_search, fast_beam_search_LG,
         fast_beam_search_nbest, fast_beam_search_nbest_oracle, and
         fast_beam_search_nbest_LG.
     Returns:
@@ -252,9 +252,7 @@ def decode_one_batch(
     supervisions = batch["supervisions"]
     feature_lens = supervisions["num_frames"].to(device)
 
-    encoder_out, encoder_out_lens = model.encoder(
-        x=feature, x_lens=feature_lens
-    )
+    encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
 
     hyps = []
 
@@ -321,10 +319,7 @@ def decode_one_batch(
         )
         for hyp in sp.decode(hyp_tokens):
            hyps.append(hyp.split())
-    elif (
-        params.decoding_method == "greedy_search"
-        and params.max_sym_per_frame == 1
-    ):
+    elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
         hyp_tokens = greedy_search_batch(
             model=model,
             encoder_out=encoder_out,
@@ -455,9 +450,7 @@ def decode_dataset(
         if batch_idx % log_interval == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(
-                f"batch {batch_str}, cuts processed until now is {num_cuts}"
-            )
+            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
 
     return results
 
@@ -469,8 +462,7 @@ def save_results(
     test_set_wers = dict()
     for key, results in results_dict.items():
         recog_path = (
-            params.res_dir /
-            f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
+            params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
         )
         results = sorted(results)
         store_transcripts(filename=recog_path, texts=results)
@@ -490,10 +482,7 @@ def save_results(
         logging.info("Wrote detailed error stats to {}".format(errs_filename))
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
-    errs_info = (
-        params.res_dir
-        / f"{wer}-{test_set_name}-{key}-{params.suffix}.txt"
-    )
+    errs_info = params.res_dir / f"{wer}-{test_set_name}-{key}-{params.suffix}.txt"
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
         for key, val in test_set_wers:
@@ -550,9 +539,7 @@ def main():
         if "LG" in params.decoding_method:
             params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
     elif "beam_search" in params.decoding_method:
-        params.suffix += (
-            f"-{params.decoding_method}-beam-size-{params.beam_size}"
-        )
+        params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
     else:
         params.suffix += f"-context-{params.context_size}"
        params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -579,8 +566,8 @@ def main(): params.unk_id = sp.piece_to_id("") params.vocab_size = sp.get_piece_size() else: - params.blank_id = lexicon.token_table.get('') - params.unk_id = lexicon.token_table.get('SPN') + params.blank_id = lexicon.token_table.get("") + params.unk_id = lexicon.token_table.get("SPN") params.vocab_size = max(lexicon.tokens) + 1 sp = None @@ -591,9 +578,9 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -606,8 +593,7 @@ def main(): ) logging.info(f"averaging {filenames}") model.to(device) - model.load_state_dict( - average_checkpoints(filenames, device=device)) + model.load_state_dict(average_checkpoints(filenames, device=device)) elif params.avg == 1: load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) else: @@ -618,13 +604,12 @@ def main(): filenames.append(f"{params.exp_dir}/epoch-{i}.pt") logging.info(f"averaging {filenames}") model.to(device) - model.load_state_dict( - average_checkpoints(filenames, device=device)) + model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -682,9 +667,7 @@ def main(): decoding_graph.scores *= params.ngram_lm_scale else: word_table = None - decoding_graph = k2.trivial_graph( - params.vocab_size - 1, device=device - ) + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) else: decoding_graph = None word_table = None @@ -695,7 +678,8 @@ def main(): join_param = sum([p.numel() for p in model.joiner.parameters()]) logging.info(f"Number of model parameters: {num_param}") logging.info( - f"Parameters for transducer decoding: {enc_param + dec_param + join_param}") + f"Parameters for transducer decoding: {enc_param + dec_param + join_param}" + ) # we need cut ids to display recognition results. args.return_cuts = True