update the decoding script

This commit is contained in:
marcoyang 2024-03-28 18:16:52 +08:00
parent cfbc829df3
commit 5d41deca71

View File

@ -348,17 +348,12 @@ def save_results(
errs_filename = (
params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
)
# we compute CER for aishell dataset.
results_char = []
for res in results:
results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
with open(errs_filename, "w") as f:
wer = write_error_stats(
f,
f"{test_set_name}-{key}",
results_char,
results,
enable_log=enable_log,
compute_CER=True,
)
test_set_wers[key] = wer
@ -366,13 +361,13 @@ def save_results(
logging.info("Wrote detailed error stats to {}".format(errs_filename))
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
errs_info = params.res_dir / f"cer-summary-{test_set_name}-{params.suffix}.txt"
errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
with open(errs_info, "w") as f:
print("settings\tCER", file=f)
print("settings\tWER", file=f)
for key, val in test_set_wers:
print("{}\t{}".format(key, val), file=f)
s = "\nFor {}, CER of different settings are:\n".format(test_set_name)
s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
note = "\tbest for {}".format(test_set_name)
for key, val in test_set_wers:
s += "{}\t{}{}\n".format(key, val, note)
@ -391,16 +386,21 @@ def main():
params.update(vars(args))
params.res_dir = params.exp_dir / params.method
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
if params.method == "beam_search":
params.suffix += f"-beam-search-beam-size-{params.beam_size}"
params.suffix += f"-whisper-{params.model_name}"
setup_logger(
f"{params.res_dir}/log-{params.method}-beam{params.beam_size}/log-decode-{params.suffix}"
f"{params.res_dir}/log-{params.method}/log-decode-{params.suffix}"
)
options = whisper.DecodingOptions(
task="transcribe",
language="en",
without_timestamps=True,
#beam_size=params.beam_size,
beam_size=params.beam_size if params.method == "beam_search" else None,
)
params.decoding_options = options
params.cleaner = BasicTextNormalizer()
params.normalizer = Normalizer()