From 5d41deca71198ad3a15104bf52bcd3258e130581 Mon Sep 17 00:00:00 2001 From: marcoyang Date: Thu, 28 Mar 2024 18:16:52 +0800 Subject: [PATCH] update the decoding script --- egs/librispeech/ASR/whisper/decode.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/egs/librispeech/ASR/whisper/decode.py b/egs/librispeech/ASR/whisper/decode.py index 83d33418d..c5f8a9406 100755 --- a/egs/librispeech/ASR/whisper/decode.py +++ b/egs/librispeech/ASR/whisper/decode.py @@ -348,17 +348,12 @@ def save_results( errs_filename = ( params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" ) - # we compute CER for aishell dataset. - results_char = [] - for res in results: - results_char.append((res[0], list("".join(res[1])), list("".join(res[2])))) with open(errs_filename, "w") as f: wer = write_error_stats( f, f"{test_set_name}-{key}", - results_char, + results, enable_log=enable_log, - compute_CER=True, ) test_set_wers[key] = wer @@ -366,13 +361,13 @@ def save_results( logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = params.res_dir / f"cer-summary-{test_set_name}-{params.suffix}.txt" + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" with open(errs_info, "w") as f: - print("settings\tCER", file=f) + print("settings\tWER", file=f) for key, val in test_set_wers: print("{}\t{}".format(key, val), file=f) - s = "\nFor {}, CER of different settings are:\n".format(test_set_name) + s = "\nFor {}, WER of different settings are:\n".format(test_set_name) note = "\tbest for {}".format(test_set_name) for key, val in test_set_wers: s += "{}\t{}{}\n".format(key, val, note) @@ -391,16 +386,21 @@ def main(): params.update(vars(args)) params.res_dir = params.exp_dir / params.method params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + if params.method == "beam_search": + params.suffix += f"-beam-search-beam-size-{params.beam_size}" + + params.suffix += f"-whisper-{params.model_name}" setup_logger( - f"{params.res_dir}/log-{params.method}-beam{params.beam_size}/log-decode-{params.suffix}" + f"{params.res_dir}/log-{params.method}/log-decode-{params.suffix}" ) options = whisper.DecodingOptions( task="transcribe", language="en", without_timestamps=True, - #beam_size=params.beam_size, + beam_size=params.beam_size if params.method == "beam_search" else None, ) + params.decoding_options = options params.cleaner = BasicTextNormalizer() params.normalizer = Normalizer()