From 9ef5dfe380d22fc056e5c2a693a9084ad661838a Mon Sep 17 00:00:00 2001
From: marcoyang
Date: Tue, 10 Oct 2023 16:48:17 +0800
Subject: [PATCH] fix style

---
 .../zipformer_prompt_asr/decode_baseline.py   | 135 +++++++++---------
 1 file changed, 65 insertions(+), 70 deletions(-)

diff --git a/egs/libriheavy/ASR/zipformer_prompt_asr/decode_baseline.py b/egs/libriheavy/ASR/zipformer_prompt_asr/decode_baseline.py
index 93530d827..6a3bab3c8 100644
--- a/egs/libriheavy/ASR/zipformer_prompt_asr/decode_baseline.py
+++ b/egs/libriheavy/ASR/zipformer_prompt_asr/decode_baseline.py
@@ -52,18 +52,15 @@ import sentencepiece as spm
 import torch
 import torch.nn as nn
 from asr_datamodule import LibriHeavyAsrDataModule
-from beam_search import (
-    greedy_search,
-    greedy_search_batch,
-    modified_beam_search,
-)
+from beam_search import greedy_search, greedy_search_batch, modified_beam_search
 from ls_text_normalization import word_normalization
-from text_normalization import ref_text_normalization, remove_non_alphabetic, upper_only_alpha
-from train_baseline import (
-    add_model_arguments,
-    get_params,
-    get_transducer_model,
+from text_normalization import (
+    ref_text_normalization,
+    remove_non_alphabetic,
+    upper_only_alpha,
 )
+from train_baseline import add_model_arguments, get_params, get_transducer_model
+from utils import write_error_stats
 
 from icefall.checkpoint import (
     average_checkpoints,
@@ -72,13 +69,7 @@ from icefall.checkpoint import (
     load_checkpoint,
 )
 from icefall.lexicon import Lexicon
-from icefall.utils import (
-    AttributeDict,
-    setup_logger,
-    store_transcripts,
-    str2bool,
-)
-from utils import write_error_stats
+from icefall.utils import AttributeDict, setup_logger, store_transcripts, str2bool
 
 LOG_EPS = math.log(1e-10)
 
@@ -242,34 +233,34 @@ def get_parser():
         Used only when the decoding method is fast_beam_search_nbest,
         fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
     )
-    
+
     parser.add_argument(
         "--post-normalization",
         type=str2bool,
         default=True,
         help="Normalized the recognition results by uppercasing and removing non-alphabetic symbols. ",
     )
-    
+
     parser.add_argument(
         "--long-audio-recog",
         type=str2bool,
         default=False,
     )
-    
+
     parser.add_argument(
         "--use-ls-test-set",
         type=str2bool,
         default=False,
-        help="Use librispeech test set for evaluation."
+        help="Use librispeech test set for evaluation.",
    )
-    
+
     parser.add_argument(
         "--compute-CER",
         type=str2bool,
         default=True,
         help="Reports CER. By default, only reports WER",
     )
-    
+
     add_model_arguments(parser)
 
     return parser
 
@@ -333,14 +324,14 @@ def decode_one_batch(
 
     supervisions = batch["supervisions"]
     feature_lens = supervisions["num_frames"].to(device)
-    
+
     with warnings.catch_warnings():
         warnings.simplefilter("ignore")
         encoder_out, encoder_out_lens = model.encode_audio(
-            feature=feature, 
+            feature=feature,
             feature_lens=feature_lens,
         )
-    
+
     hyps = []
 
     if params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
@@ -373,12 +364,6 @@
                     encoder_out=encoder_out_i,
                     max_sym_per_frame=params.max_sym_per_frame,
                 )
-            elif params.decoding_method == "beam_search":
-                hyp = beam_search(
-                    model=model,
-                    encoder_out=encoder_out_i,
-                    beam=params.beam_size,
-                )
             else:
                 raise ValueError(
                     f"Unsupported decoding method: {params.decoding_method}"
@@ -442,7 +427,9 @@ def decode_dataset(
         texts = batch["supervisions"]["text"]
         cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
         if not params.use_ls_test_set:
-            book_names = [cut.text_path.split('/')[-2] for cut in batch["supervisions"]["cut"]]
+            book_names = [
+                cut.text_path.split("/")[-2] for cut in batch["supervisions"]["cut"]
+            ]
         else:
             book_names = ["" for _ in cut_ids]
 
@@ -458,10 +445,10 @@
         for name, hyps in hyps_dict.items():
             this_batch = []
             assert len(hyps) == len(texts)
-            for cut_id, book_name, hyp_words, ref_text in zip(cut_ids, book_names, hyps, texts):
-                ref_text = ref_text_normalization(
-                    ref_text
-                )
+            for cut_id, book_name, hyp_words, ref_text in zip(
+                cut_ids, book_names, hyps, texts
+            ):
+                ref_text = ref_text_normalization(ref_text)
                 ref_words = ref_text.split()
                 this_batch.append((cut_id, ref_words, hyp_words))
                 # if not params.use_ls_test_set:
@@ -496,7 +483,11 @@ def save_results(
         errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
         with open(errs_filename, "w") as f:
             wer = write_error_stats(
-                f, f"{test_set_name}-{key}", results, enable_log=True, biasing_words=biasing_words, 
+                f,
+                f"{test_set_name}-{key}",
+                results,
+                enable_log=True,
+                biasing_words=biasing_words,
             )
             test_set_wers[key] = wer
 
@@ -504,7 +495,9 @@
         if params.compute_CER:
             # Write CER statistics
-            recog_path = params.res_dir / f"recogs-{test_set_name}-char-{params.suffix}.txt"
+            recog_path = (
+                params.res_dir / f"recogs-{test_set_name}-char-{params.suffix}.txt"
+            )
             store_transcripts(filename=recog_path, texts=results, char_level=True)
             errs_filename = (
                 params.res_dir / f"errs-CER-{test_set_name}-{params.suffix}.txt"
             )
@@ -522,9 +515,7 @@
             logging.info("Wrote detailed CER stats to {}".format(errs_filename))
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
-    errs_info = (
-        params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
-    )
+    errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
         for key, val in test_set_wers:
@@ -539,9 +530,7 @@
 
     if params.compute_CER:
         test_set_cers = sorted(test_set_cers.items(), key=lambda x: x[1])
-        errs_info = (
-            params.res_dir / f"cer-summary-{test_set_name}-{params.suffix}.txt"
-        )
+        errs_info = params.res_dir / f"cer-summary-{test_set_name}-{params.suffix}.txt"
         with open(errs_info, "w") as f:
             print("settings\tcER", file=f)
             for key, val in test_set_cers:
@@ -569,7 +558,7 @@ def main():
         "greedy_search",
         "modified_beam_search",
     )
-    
+
     if params.long_audio_recog:
         params.res_dir = params.exp_dir / (params.decoding_method + "long_audio")
     else:
@@ -579,9 +568,6 @@ def main():
         params.suffix = f"iter-{params.iter}-avg-{params.avg}"
     else:
         params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
-
-    if params.random_left_padding:
-        params.suffix += f"random-left-padding"
 
     if "beam_search" in params.decoding_method:
         params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
@@ -591,7 +577,7 @@
 
     if "ngram" in params.decoding_method:
         params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
-    
+
     if params.use_averaged_model:
         params.suffix += "-use-averaged-model"
 
@@ -706,7 +692,7 @@
 
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")
-    
+
     # we need cut ids to display recognition results.
     args.return_cuts = True
     libriheavy = LibriHeavyAsrDataModule(args)
@@ -717,26 +703,34 @@
     ls_test_other_cuts = libriheavy.librispeech_test_other_cuts()
     long_audio_cuts = libriheavy.long_audio_cuts()
 
-    test_clean_dl = libriheavy.valid_dataloaders(test_clean_cuts,)
-    test_other_dl = libriheavy.valid_dataloaders(test_other_cuts,)
+    test_clean_dl = libriheavy.valid_dataloaders(
+        test_clean_cuts,
+    )
+    test_other_dl = libriheavy.valid_dataloaders(
+        test_other_cuts,
+    )
     ls_test_clean_dl = libriheavy.test_dataloaders(ls_test_clean_cuts)
     ls_test_other_dl = libriheavy.test_dataloaders(ls_test_other_cuts)
-    long_audio_dl = libriheavy.valid_dataloaders(long_audio_cuts,)
-    
+    long_audio_dl = libriheavy.valid_dataloaders(
+        long_audio_cuts,
+    )
+
     if params.use_ls_test_set:
         test_sets = ["ls-test-clean", "ls-test-other"]
         test_dl = [ls_test_clean_dl, ls_test_other_dl]
     else:
         test_sets = ["test-clean", "test-other"]
         test_dl = [test_clean_dl, test_other_dl]
-    
+
     if params.long_audio_recog:
         test_sets = ["long-audio"]
-        test_dl = [long_audio_dl] 
-    
+        test_dl = [long_audio_dl]
+
     for test_set, test_dl in zip(test_sets, test_dl):
         if params.use_ls_test_set:
-            f = open("data/context_biasing/LibriSpeechBiasingLists/all_rare_words.txt", 'r')
+            f = open(
+                "data/context_biasing/LibriSpeechBiasingLists/all_rare_words.txt", "r"
+            )
             biasing_words = f.read().strip().split()
             f.close()
         else:
@@ -747,7 +741,7 @@
             model=model,
             sp=sp,
             word_table=word_table,
-            decoding_graph=decoding_graph, 
+            decoding_graph=decoding_graph,
         )
 
         save_results(
@@ -755,40 +749,41 @@
             test_set_name=test_set,
             results_dict=results_dict,
         )
-        
+
         if params.post_normalization:
             if "-post-normalization" not in params.suffix:
                 params.suffix += "-post-normalization"
-            
+
             new_res = {}
             for k in results_dict:
                 new_ans = []
                 for item in results_dict[k]:
                     id, ref, hyp = item
                     if params.use_ls_test_set:
-                        hyp = " ".join(hyp).replace("-", " ").split() # handle the hypens
+                        hyp = (
+                            " ".join(hyp).replace("-", " ").split()
+                        )  # handle the hyphens
                         hyp = upper_only_alpha(" ".join(hyp)).split()
                         hyp = [word_normalization(w.upper()) for w in hyp]
                         hyp = " ".join(hyp).split()
                         hyp = [w for w in hyp if w != ""]
-                        ref = upper_only_alpha(" ".join(ref)).split() 
+                        ref = upper_only_alpha(" ".join(ref)).split()
                     else:
                         hyp = upper_only_alpha(" ".join(hyp)).split()
-                        ref = upper_only_alpha(" ".join(ref)).split() 
-                    new_ans.append((id,ref,hyp))
-                new_res[k] = new_ans
-            
+                        ref = upper_only_alpha(" ".join(ref)).split()
+                    new_ans.append((id, ref, hyp))
+                new_res[k] = new_ans
+
             save_results(
                 params=params,
                 test_set_name=test_set,
                 results_dict=new_res,
                 biasing_words=biasing_words,
             )
-            
+
             if params.suffix.endswith("-post-normalization"):
                 params.suffix = params.suffix.replace("-post-normalization", "")
-    
 
     logging.info("Done!")