mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 18:12:19 +00:00
update the decoding script
This commit is contained in:
parent
cfbc829df3
commit
5d41deca71
@ -348,17 +348,12 @@ def save_results(
|
|||||||
errs_filename = (
|
errs_filename = (
|
||||||
params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
|
params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
|
||||||
)
|
)
|
||||||
# we compute CER for aishell dataset.
|
|
||||||
results_char = []
|
|
||||||
for res in results:
|
|
||||||
results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
|
|
||||||
with open(errs_filename, "w") as f:
|
with open(errs_filename, "w") as f:
|
||||||
wer = write_error_stats(
|
wer = write_error_stats(
|
||||||
f,
|
f,
|
||||||
f"{test_set_name}-{key}",
|
f"{test_set_name}-{key}",
|
||||||
results_char,
|
results,
|
||||||
enable_log=enable_log,
|
enable_log=enable_log,
|
||||||
compute_CER=True,
|
|
||||||
)
|
)
|
||||||
test_set_wers[key] = wer
|
test_set_wers[key] = wer
|
||||||
|
|
||||||
@ -366,13 +361,13 @@ def save_results(
|
|||||||
logging.info("Wrote detailed error stats to {}".format(errs_filename))
|
logging.info("Wrote detailed error stats to {}".format(errs_filename))
|
||||||
|
|
||||||
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
|
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
|
||||||
errs_info = params.res_dir / f"cer-summary-{test_set_name}-{params.suffix}.txt"
|
errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
|
||||||
with open(errs_info, "w") as f:
|
with open(errs_info, "w") as f:
|
||||||
print("settings\tCER", file=f)
|
print("settings\tWER", file=f)
|
||||||
for key, val in test_set_wers:
|
for key, val in test_set_wers:
|
||||||
print("{}\t{}".format(key, val), file=f)
|
print("{}\t{}".format(key, val), file=f)
|
||||||
|
|
||||||
s = "\nFor {}, CER of different settings are:\n".format(test_set_name)
|
s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
|
||||||
note = "\tbest for {}".format(test_set_name)
|
note = "\tbest for {}".format(test_set_name)
|
||||||
for key, val in test_set_wers:
|
for key, val in test_set_wers:
|
||||||
s += "{}\t{}{}\n".format(key, val, note)
|
s += "{}\t{}{}\n".format(key, val, note)
|
||||||
@ -391,16 +386,21 @@ def main():
|
|||||||
params.update(vars(args))
|
params.update(vars(args))
|
||||||
params.res_dir = params.exp_dir / params.method
|
params.res_dir = params.exp_dir / params.method
|
||||||
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
|
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
|
||||||
|
if params.method == "beam_search":
|
||||||
|
params.suffix += f"-beam-search-beam-size-{params.beam_size}"
|
||||||
|
|
||||||
|
params.suffix += f"-whisper-{params.model_name}"
|
||||||
setup_logger(
|
setup_logger(
|
||||||
f"{params.res_dir}/log-{params.method}-beam{params.beam_size}/log-decode-{params.suffix}"
|
f"{params.res_dir}/log-{params.method}/log-decode-{params.suffix}"
|
||||||
)
|
)
|
||||||
|
|
||||||
options = whisper.DecodingOptions(
|
options = whisper.DecodingOptions(
|
||||||
task="transcribe",
|
task="transcribe",
|
||||||
language="en",
|
language="en",
|
||||||
without_timestamps=True,
|
without_timestamps=True,
|
||||||
#beam_size=params.beam_size,
|
beam_size=params.beam_size if params.method == "beam_search" else None,
|
||||||
)
|
)
|
||||||
|
|
||||||
params.decoding_options = options
|
params.decoding_options = options
|
||||||
params.cleaner = BasicTextNormalizer()
|
params.cleaner = BasicTextNormalizer()
|
||||||
params.normalizer = Normalizer()
|
params.normalizer = Normalizer()
|
||||||
|
Loading…
x
Reference in New Issue
Block a user