Minor fixes.

This commit is contained in:
Fangjun Kuang 2021-12-15 08:43:38 +08:00
parent e38f04e70f
commit cbda811a10

View File

@ -1,5 +1,6 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
# Copyright 2021 Xiaomi Corporation (Author: Liyong Guo, Fangjun Kuang) #
# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang)
# #
# See ../../../../LICENSE for clarification regarding multiple authors # See ../../../../LICENSE for clarification regarding multiple authors
# #
@ -280,13 +281,17 @@ def save_results(
): ):
test_set_wers = dict() test_set_wers = dict()
for key, results in results_dict.items(): for key, results in results_dict.items():
recog_path = params.exp_dir / f"recogs-{test_set_name}-{key}.txt" recog_path = (
params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
)
store_transcripts(filename=recog_path, texts=results) store_transcripts(filename=recog_path, texts=results)
logging.info(f"The transcripts are stored in {recog_path}") logging.info(f"The transcripts are stored in {recog_path}")
# The following prints out WERs, per-word error statistics and aligned # The following prints out WERs, per-word error statistics and aligned
# ref/hyp pairs. # ref/hyp pairs.
errs_filename = params.exp_dir / f"errs-{test_set_name}-{key}.txt" errs_filename = (
params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
)
with open(errs_filename, "w") as f: with open(errs_filename, "w") as f:
wer = write_error_stats( wer = write_error_stats(
f, f"{test_set_name}-{key}", results, enable_log=True f, f"{test_set_name}-{key}", results, enable_log=True
@ -296,7 +301,10 @@ def save_results(
logging.info("Wrote detailed error stats to {}".format(errs_filename)) logging.info("Wrote detailed error stats to {}".format(errs_filename))
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
errs_info = params.exp_dir / f"wer-summary-{test_set_name}.txt" errs_info = (
params.res_dir
/ f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
)
with open(errs_info, "w") as f: with open(errs_info, "w") as f:
print("settings\tWER", file=f) print("settings\tWER", file=f)
for key, val in test_set_wers: for key, val in test_set_wers:
@ -319,8 +327,10 @@ def main():
params = get_params() params = get_params()
params.update(vars(args)) params.update(vars(args))
params.res_dir = params.exp_dir / "greedy_search"
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
setup_logger(f"{params.exp_dir}/log-decode") setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
logging.info("Decoding started") logging.info("Decoding started")
device = torch.device("cpu") device = torch.device("cpu")