This commit is contained in:
Guo Liyong 2021-12-02 17:46:14 +08:00
parent a4722dd7c0
commit 54bcc167e1
2 changed files with 7 additions and 4 deletions

View File

@ -501,7 +501,8 @@ def save_results(
for key, results in results_dict.items():
recog_path = (
params.exp_dir
/ f"epoch-{params.epoch}-avg-{params.avg}-recogs-{test_set_name}-{key}.txt"
/ f"epoch-{params.epoch}-avg-{params.avg}- \
recogs-{test_set_name}-{key}.txt"
)
store_transcripts(filename=recog_path, texts=results)
if enable_log:
@ -511,7 +512,8 @@ def save_results(
# ref/hyp pairs.
errs_filename = (
params.exp_dir
/ f"epoch-{params.epoch}-avg-{params.avg}-errs-{test_set_name}-{key}.txt"
/ f"epoch-{params.epoch}-avg-{params.avg}- \
errs-{test_set_name}-{key}.txt"
)
with open(errs_filename, "w") as f:
wer = write_error_stats(
@ -527,7 +529,8 @@ def save_results(
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
errs_info = (
params.exp_dir
/ f"epoch-{params.epoch}-avg-{params.avg}-wer-summary-{test_set_name}.txt"
/ f"epoch-{params.epoch}-avg-{params.avg}- \
wer-summary-{test_set_name}.txt"
)
with open(errs_info, "w") as f:
print("settings\tWER", file=f)

View File

@ -439,7 +439,7 @@ def compute_loss(
info["att_loss"] = att_loss.detach().cpu().item()
if params.codebook_weight != 0.0:
info["codebook_loss"] = cdidx_loss.detach().cpu().item()
info["codebook_loss"] = codebook_loss.detach().cpu().item()
info["loss"] = loss.detach().cpu().item()