from local

This commit is contained in:
dohe0342 2023-02-14 15:18:18 +09:00
parent 4527fabfe2
commit a568e37484
2 changed files with 10 additions and 4 deletions

View File

@ -628,14 +628,15 @@ def save_results(
for key, val in test_set_wers:
print("{}\t{}".format(key, val), file=f)
wer = None
s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
note = "\tbest for {}".format(test_set_name)
for key, val in test_set_wers:
s += "{}\t{}{}\n".format(key, val, note)
note = ""
print(key)
print(val)
wer = val
logging.info(s)
return val
@torch.no_grad()
@ -816,7 +817,12 @@ def main() -> None:
eos_id=eos_id,
)
save_results(params=params, test_set_name=test_set, results_dict=results_dict)
wer = save_results(params=params, test_set_name=test_set, results_dict=results_dict)
wer_dict[epoch] = wer
wer_dict = sorted(wer_dict.items(), key=lambda x:x[1])
for k, v in wer_dict:
print(k, v)
torch.set_num_threads(1)
# when we import add_model_arguments from train.py