add CER numbers

This commit is contained in:
Desh Raj 2022-05-15 08:08:49 -04:00
parent 4fc1638959
commit ed30271715
3 changed files with 44 additions and 28 deletions

View File

@ -23,11 +23,10 @@ ArXiv link: https://arxiv.org/abs/2104.02014
## Performance Record
| Decoding method | val |
|---------------------------|------------|
| greedy search | 2.40 |
| beam search | 2.24 |
| modified beam search | 2.24 |
| fast beam search | 2.35 |
| Decoding method | val WER | val CER |
|---------------------------|------------|---------|
| greedy search | 2.40 | 0.99 |
| modified beam search | 2.24 | 0.91 |
| fast beam search | 2.35 | 0.97 |
See [RESULTS](/egs/spgispeech/ASR/RESULTS.md) for details.

View File

@ -15,7 +15,6 @@ The WERs are
| | dev | val | comment |
|---------------------------|------------|------------|------------------------------------------|
| greedy search | 2.46 | 2.40 | --avg-last-n 10 --max-duration 500 |
| beam search | 2.27 | 2.24 | --avg-last-n 10 --max-duration 500 --beam-size 4 |
| modified beam search | 2.28 | 2.24 | --avg-last-n 10 --max-duration 500 --beam-size 4 |
| fast beam search | 2.38 | 2.35 | --avg-last-n 10 --max-duration 500 --beam-size 4 --max-contexts 4 --max-states 8 |
@ -49,14 +48,6 @@ The decoding command is:
--max-duration 100 \
--decoding-method greedy_search
# beam search
./pruned_transducer_stateless2/decode.py \
--iter 696000 --avg 10 \
--exp-dir ./pruned_transducer_stateless2/exp \
--max-duration 100 \
--decoding-method beam_search \
--beam-size 4
# modified beam search
./pruned_transducer_stateless2/decode.py \
--iter 696000 --avg 10 \

View File

@ -69,7 +69,7 @@ import torch.nn as nn
from asr_datamodule import SPGISpeechAsrDataModule
from beam_search import (
beam_search,
fast_beam_search,
fast_beam_search_one_best,
greedy_search,
greedy_search_batch,
modified_beam_search,
@ -252,7 +252,7 @@ def decode_one_batch(
hyps = []
if params.decoding_method == "fast_beam_search":
hyp_tokens = fast_beam_search(
hyp_tokens = fast_beam_search_one_best(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
@ -400,6 +400,7 @@ def save_results(
results_dict: Dict[str, List[Tuple[List[int], List[int]]]],
):
test_set_wers = dict()
test_set_cers = dict()
for key, results in results_dict.items():
recog_path = (
params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
@ -409,31 +410,56 @@ def save_results(
# The following prints out WERs, per-word error statistics and aligned
# ref/hyp pairs.
errs_filename = (
params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
wers_filename = (
params.res_dir / f"wers-{test_set_name}-{key}-{params.suffix}.txt"
)
with open(errs_filename, "w") as f:
with open(wers_filename, "w") as f:
wer = write_error_stats(
f, f"{test_set_name}-{key}", results, enable_log=True
)
test_set_wers[key] = wer
logging.info("Wrote detailed error stats to {}".format(errs_filename))
# we also compute CER for spgispeech dataset.
results_char = []
for res in results:
results_char.append((list("".join(res[0])), list("".join(res[1]))))
cers_filename = (
params.res_dir / f"cers-{test_set_name}-{key}-{params.suffix}.txt"
)
with open(cers_filename, "w") as f:
cer = write_error_stats(
f, f"{test_set_name}-{key}", results_char, enable_log=True
)
test_set_cers[key] = cer
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
logging.info("Wrote detailed error stats to {}".format(wers_filename))
test_set_wers = {
k: v for k, v in sorted(test_set_wers.items(), key=lambda x: x[1])
}
test_set_cers = {
k: v for k, v in sorted(test_set_cers.items(), key=lambda x: x[1])
}
errs_info = (
params.res_dir
/ f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
)
with open(errs_info, "w") as f:
print("settings\tWER", file=f)
for key, val in test_set_wers:
print("{}\t{}".format(key, val), file=f)
print("settings\tWER\tCER", file=f)
for key in test_set_wers:
print(
"{}\t{}\t{}".format(
key, test_set_wers[key], test_set_cers[key]
),
file=f,
)
s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
s = "\nFor {}, WER/CER of different settings are:\n".format(test_set_name)
note = "\tbest for {}".format(test_set_name)
for key, val in test_set_wers:
s += "{}\t{}{}\n".format(key, val, note)
for key in test_set_wers:
s += "{}\t{}\t{}{}\n".format(
key, test_set_wers[key], test_set_cers[key], note
)
note = ""
logging.info(s)