From 26c022665b2f0f824e6a23c72dd38b239f2a67c6 Mon Sep 17 00:00:00 2001 From: yfyeung Date: Tue, 29 Apr 2025 10:32:12 -0700 Subject: [PATCH] remove space if existing fix --- .../ASR_LLM/whisper_llm_zh/decode.py | 32 ++++++++----------- .../ASR_LLM/whisper_llm_zh/train.py | 1 + 2 files changed, 15 insertions(+), 18 deletions(-) diff --git a/egs/speech_llm/ASR_LLM/whisper_llm_zh/decode.py b/egs/speech_llm/ASR_LLM/whisper_llm_zh/decode.py index 91af6ffec..e5e48d09c 100755 --- a/egs/speech_llm/ASR_LLM/whisper_llm_zh/decode.py +++ b/egs/speech_llm/ASR_LLM/whisper_llm_zh/decode.py @@ -382,9 +382,7 @@ def decode_dataset( for lm_scale, hyps in hyps_dict.items(): this_batch = [] assert len(hyps) == len(texts) - for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): - print(f"ref: {''.join(ref_text)}") - print(f"hyp: {''.join(hyp_words)}") + for cut_id, hyp_text, ref_text in zip(cut_ids, hyps, texts): this_batch.append((cut_id, ref_text, hyp_words)) results[lm_scale].extend(this_batch) @@ -401,40 +399,38 @@ def decode_dataset( def save_results( params: AttributeDict, test_set_name: str, - results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], + results_dict: Dict[str, List[Tuple[List[int], List[int]]]], ): - - enable_log = True test_set_wers = dict() for key, results in results_dict.items(): recog_path = ( params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" ) results = sorted(results) - store_transcripts(filename=recog_path, texts=results) - if enable_log: - logging.info(f"The transcripts are stored in {recog_path}") + store_transcripts(filename=recog_path, texts=results, char_level=True) + logging.info(f"The transcripts are stored in {recog_path}") - # The following prints out WERs, per-word error statistics and aligned + # The following prints out CERs, per-word error statistics and aligned # ref/hyp pairs. errs_filename = ( params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" ) - # we compute CER for aishell dataset. - results_char = [] - for res in results: - results_char.append((res[0], list("".join(res[1])), list("".join(res[2])))) with open(errs_filename, "w") as f: wer = write_error_stats( - f, f"{test_set_name}-{key}", results_char, enable_log=enable_log + f, + f"{test_set_name}-{key}", + results, + enable_log=True, + compute_CER=True, ) test_set_wers[key] = wer - if enable_log: - logging.info("Wrote detailed error stats to {}".format(errs_filename)) + logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = params.res_dir / f"cer-summary-{test_set_name}-{params.suffix}.txt" + errs_info = ( + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + ) with open(errs_info, "w") as f: print("settings\tCER", file=f) for key, val in test_set_wers: diff --git a/egs/speech_llm/ASR_LLM/whisper_llm_zh/train.py b/egs/speech_llm/ASR_LLM/whisper_llm_zh/train.py index acf6daca7..7947a60a5 100755 --- a/egs/speech_llm/ASR_LLM/whisper_llm_zh/train.py +++ b/egs/speech_llm/ASR_LLM/whisper_llm_zh/train.py @@ -339,6 +339,7 @@ def compute_loss( messages = [] for i, text in enumerate(texts): + text = text.replace(" ", "") message = [ {"role": "user", "content": f"{DEFAULT_SPEECH_TOKEN}请转写音频为文字"}, {"role": "assistant", "content": text},