mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-07 08:04:18 +00:00
add CER numbers
This commit is contained in:
parent
4fc1638959
commit
ed30271715
@ -23,11 +23,10 @@ ArXiv link: https://arxiv.org/abs/2104.02014
|
||||
|
||||
## Performance Record
|
||||
|
||||
| Decoding method | val |
|
||||
|---------------------------|------------|
|
||||
| greedy search | 2.40 |
|
||||
| beam search | 2.24 |
|
||||
| modified beam search | 2.24 |
|
||||
| fast beam search | 2.35 |
|
||||
| Decoding method | val WER | val CER |
|
||||
|---------------------------|------------|---------|
|
||||
| greedy search | 2.40 | 0.99 |
|
||||
| modified beam search | 2.24 | 0.91 |
|
||||
| fast beam search | 2.35 | 0.97 |
|
||||
|
||||
See [RESULTS](/egs/spgispeech/ASR/RESULTS.md) for details.
|
||||
|
@ -15,7 +15,6 @@ The WERs are
|
||||
| | dev | val | comment |
|
||||
|---------------------------|------------|------------|------------------------------------------|
|
||||
| greedy search | 2.46 | 2.40 | --avg-last-n 10 --max-duration 500 |
|
||||
| beam search | 2.27 | 2.24 | --avg-last-n 10 --max-duration 500 --beam-size 4 |
|
||||
| modified beam search | 2.28 | 2.24 | --avg-last-n 10 --max-duration 500 --beam-size 4 |
|
||||
| fast beam search | 2.38 | 2.35 | --avg-last-n 10 --max-duration 500 --beam-size 4 --max-contexts 4 --max-states 8 |
|
||||
|
||||
@ -49,14 +48,6 @@ The decoding command is:
|
||||
--max-duration 100 \
|
||||
--decoding-method greedy_search
|
||||
|
||||
# beam search
|
||||
./pruned_transducer_stateless2/decode.py \
|
||||
--iter 696000 --avg 10 \
|
||||
--exp-dir ./pruned_transducer_stateless2/exp \
|
||||
--max-duration 100 \
|
||||
--decoding-method beam_search \
|
||||
--beam-size 4
|
||||
|
||||
# modified beam search
|
||||
./pruned_transducer_stateless2/decode.py \
|
||||
--iter 696000 --avg 10 \
|
||||
|
@ -69,7 +69,7 @@ import torch.nn as nn
|
||||
from asr_datamodule import SPGISpeechAsrDataModule
|
||||
from beam_search import (
|
||||
beam_search,
|
||||
fast_beam_search,
|
||||
fast_beam_search_one_best,
|
||||
greedy_search,
|
||||
greedy_search_batch,
|
||||
modified_beam_search,
|
||||
@ -252,7 +252,7 @@ def decode_one_batch(
|
||||
hyps = []
|
||||
|
||||
if params.decoding_method == "fast_beam_search":
|
||||
hyp_tokens = fast_beam_search(
|
||||
hyp_tokens = fast_beam_search_one_best(
|
||||
model=model,
|
||||
decoding_graph=decoding_graph,
|
||||
encoder_out=encoder_out,
|
||||
@ -400,6 +400,7 @@ def save_results(
|
||||
results_dict: Dict[str, List[Tuple[List[int], List[int]]]],
|
||||
):
|
||||
test_set_wers = dict()
|
||||
test_set_cers = dict()
|
||||
for key, results in results_dict.items():
|
||||
recog_path = (
|
||||
params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
|
||||
@ -409,31 +410,56 @@ def save_results(
|
||||
|
||||
# The following prints out WERs, per-word error statistics and aligned
|
||||
# ref/hyp pairs.
|
||||
errs_filename = (
|
||||
params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
|
||||
wers_filename = (
|
||||
params.res_dir / f"wers-{test_set_name}-{key}-{params.suffix}.txt"
|
||||
)
|
||||
with open(errs_filename, "w") as f:
|
||||
with open(wers_filename, "w") as f:
|
||||
wer = write_error_stats(
|
||||
f, f"{test_set_name}-{key}", results, enable_log=True
|
||||
)
|
||||
test_set_wers[key] = wer
|
||||
|
||||
logging.info("Wrote detailed error stats to {}".format(errs_filename))
|
||||
# We also compute the CER for the SPGISpeech dataset.
|
||||
results_char = []
|
||||
for res in results:
|
||||
results_char.append((list("".join(res[0])), list("".join(res[1]))))
|
||||
cers_filename = (
|
||||
params.res_dir / f"cers-{test_set_name}-{key}-{params.suffix}.txt"
|
||||
)
|
||||
with open(cers_filename, "w") as f:
|
||||
cer = write_error_stats(
|
||||
f, f"{test_set_name}-{key}", results_char, enable_log=True
|
||||
)
|
||||
test_set_cers[key] = cer
|
||||
|
||||
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
|
||||
logging.info("Wrote detailed error stats to {}".format(wers_filename))
|
||||
|
||||
test_set_wers = {
|
||||
k: v for k, v in sorted(test_set_wers.items(), key=lambda x: x[1])
|
||||
}
|
||||
test_set_cers = {
|
||||
k: v for k, v in sorted(test_set_cers.items(), key=lambda x: x[1])
|
||||
}
|
||||
errs_info = (
|
||||
params.res_dir
|
||||
/ f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
|
||||
)
|
||||
with open(errs_info, "w") as f:
|
||||
print("settings\tWER", file=f)
|
||||
for key, val in test_set_wers:
|
||||
print("{}\t{}".format(key, val), file=f)
|
||||
print("settings\tWER\tCER", file=f)
|
||||
for key in test_set_wers:
|
||||
print(
|
||||
"{}\t{}\t{}".format(
|
||||
key, test_set_wers[key], test_set_cers[key]
|
||||
),
|
||||
file=f,
|
||||
)
|
||||
|
||||
s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
|
||||
s = "\nFor {}, WER/CER of different settings are:\n".format(test_set_name)
|
||||
note = "\tbest for {}".format(test_set_name)
|
||||
for key, val in test_set_wers:
|
||||
s += "{}\t{}{}\n".format(key, val, note)
|
||||
for key in test_set_wers:
|
||||
s += "{}\t{}\t{}{}\n".format(
|
||||
key, test_set_wers[key], test_set_cers[key], note
|
||||
)
|
||||
note = ""
|
||||
logging.info(s)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user