from local

This commit is contained in:
dohe0342 2023-02-14 15:18:18 +09:00
parent 4527fabfe2
commit a568e37484
2 changed files with 10 additions and 4 deletions

View File

@ -628,14 +628,15 @@ def save_results(
for key, val in test_set_wers: for key, val in test_set_wers:
print("{}\t{}".format(key, val), file=f) print("{}\t{}".format(key, val), file=f)
wer = None
s = "\nFor {}, WER of different settings are:\n".format(test_set_name) s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
note = "\tbest for {}".format(test_set_name) note = "\tbest for {}".format(test_set_name)
for key, val in test_set_wers: for key, val in test_set_wers:
s += "{}\t{}{}\n".format(key, val, note) s += "{}\t{}{}\n".format(key, val, note)
note = "" note = ""
print(key) wer = val
print(val)
logging.info(s) logging.info(s)
return val
@torch.no_grad() @torch.no_grad()
@ -816,7 +817,12 @@ def main() -> None:
eos_id=eos_id, eos_id=eos_id,
) )
save_results(params=params, test_set_name=test_set, results_dict=results_dict) wer = save_results(params=params, test_set_name=test_set, results_dict=results_dict)
wer_dict[epoch] = wer
wer_dict = sorted(wer_dict.items(), key=lambda x:x[1])
for k, v in wer_dict:
print(k, v)
torch.set_num_threads(1) torch.set_num_threads(1)
# when we import add_model_arguments from train.py # when we import add_model_arguments from train.py