mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-08 08:34:19 +00:00
add CER numbers
This commit is contained in:
parent
4fc1638959
commit
ed30271715
@ -23,11 +23,10 @@ ArXiv link: https://arxiv.org/abs/2104.02014
|
|||||||
|
|
||||||
## Performance Record
|
## Performance Record
|
||||||
|
|
||||||
| Decoding method | val |
|
| Decoding method | val WER | val CER |
|
||||||
|---------------------------|------------|
|
|---------------------------|------------|---------|
|
||||||
| greedy search | 2.40 |
|
| greedy search | 2.40 | 0.99 |
|
||||||
| beam search | 2.24 |
|
| modified beam search | 2.24 | 0.91 |
|
||||||
| modified beam search | 2.24 |
|
| fast beam search | 2.35 | 0.97 |
|
||||||
| fast beam search | 2.35 |
|
|
||||||
|
|
||||||
See [RESULTS](/egs/spgispeech/ASR/RESULTS.md) for details.
|
See [RESULTS](/egs/spgispeech/ASR/RESULTS.md) for details.
|
||||||
|
@ -15,7 +15,6 @@ The WERs are
|
|||||||
| | dev | val | comment |
|
| | dev | val | comment |
|
||||||
|---------------------------|------------|------------|------------------------------------------|
|
|---------------------------|------------|------------|------------------------------------------|
|
||||||
| greedy search | 2.46 | 2.40 | --avg-last-n 10 --max-duration 500 |
|
| greedy search | 2.46 | 2.40 | --avg-last-n 10 --max-duration 500 |
|
||||||
| beam search | 2.27 | 2.24 | --avg-last-n 10 --max-duration 500 --beam-size 4 |
|
|
||||||
| modified beam search | 2.28 | 2.24 | --avg-last-n 10 --max-duration 500 --beam-size 4 |
|
| modified beam search | 2.28 | 2.24 | --avg-last-n 10 --max-duration 500 --beam-size 4 |
|
||||||
| fast beam search | 2.38 | 2.35 | --avg-last-n 10 --max-duration 500 --beam-size 4 --max-contexts 4 --max-states 8 |
|
| fast beam search | 2.38 | 2.35 | --avg-last-n 10 --max-duration 500 --beam-size 4 --max-contexts 4 --max-states 8 |
|
||||||
|
|
||||||
@ -49,14 +48,6 @@ The decoding command is:
|
|||||||
--max-duration 100 \
|
--max-duration 100 \
|
||||||
--decoding-method greedy_search
|
--decoding-method greedy_search
|
||||||
|
|
||||||
# beam search
|
|
||||||
./pruned_transducer_stateless2/decode.py \
|
|
||||||
--iter 696000 --avg 10 \
|
|
||||||
--exp-dir ./pruned_transducer_stateless2/exp \
|
|
||||||
--max-duration 100 \
|
|
||||||
--decoding-method beam_search \
|
|
||||||
--beam-size 4
|
|
||||||
|
|
||||||
# modified beam search
|
# modified beam search
|
||||||
./pruned_transducer_stateless2/decode.py \
|
./pruned_transducer_stateless2/decode.py \
|
||||||
--iter 696000 --avg 10 \
|
--iter 696000 --avg 10 \
|
||||||
|
@ -69,7 +69,7 @@ import torch.nn as nn
|
|||||||
from asr_datamodule import SPGISpeechAsrDataModule
|
from asr_datamodule import SPGISpeechAsrDataModule
|
||||||
from beam_search import (
|
from beam_search import (
|
||||||
beam_search,
|
beam_search,
|
||||||
fast_beam_search,
|
fast_beam_search_one_best,
|
||||||
greedy_search,
|
greedy_search,
|
||||||
greedy_search_batch,
|
greedy_search_batch,
|
||||||
modified_beam_search,
|
modified_beam_search,
|
||||||
@ -252,7 +252,7 @@ def decode_one_batch(
|
|||||||
hyps = []
|
hyps = []
|
||||||
|
|
||||||
if params.decoding_method == "fast_beam_search":
|
if params.decoding_method == "fast_beam_search":
|
||||||
hyp_tokens = fast_beam_search(
|
hyp_tokens = fast_beam_search_one_best(
|
||||||
model=model,
|
model=model,
|
||||||
decoding_graph=decoding_graph,
|
decoding_graph=decoding_graph,
|
||||||
encoder_out=encoder_out,
|
encoder_out=encoder_out,
|
||||||
@ -400,6 +400,7 @@ def save_results(
|
|||||||
results_dict: Dict[str, List[Tuple[List[int], List[int]]]],
|
results_dict: Dict[str, List[Tuple[List[int], List[int]]]],
|
||||||
):
|
):
|
||||||
test_set_wers = dict()
|
test_set_wers = dict()
|
||||||
|
test_set_cers = dict()
|
||||||
for key, results in results_dict.items():
|
for key, results in results_dict.items():
|
||||||
recog_path = (
|
recog_path = (
|
||||||
params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
|
params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
|
||||||
@ -409,31 +410,56 @@ def save_results(
|
|||||||
|
|
||||||
# The following prints out WERs, per-word error statistics and aligned
|
# The following prints out WERs, per-word error statistics and aligned
|
||||||
# ref/hyp pairs.
|
# ref/hyp pairs.
|
||||||
errs_filename = (
|
wers_filename = (
|
||||||
params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
|
params.res_dir / f"wers-{test_set_name}-{key}-{params.suffix}.txt"
|
||||||
)
|
)
|
||||||
with open(errs_filename, "w") as f:
|
with open(wers_filename, "w") as f:
|
||||||
wer = write_error_stats(
|
wer = write_error_stats(
|
||||||
f, f"{test_set_name}-{key}", results, enable_log=True
|
f, f"{test_set_name}-{key}", results, enable_log=True
|
||||||
)
|
)
|
||||||
test_set_wers[key] = wer
|
test_set_wers[key] = wer
|
||||||
|
|
||||||
logging.info("Wrote detailed error stats to {}".format(errs_filename))
|
# we also compute CER for spgispeech dataset.
|
||||||
|
results_char = []
|
||||||
|
for res in results:
|
||||||
|
results_char.append((list("".join(res[0])), list("".join(res[1]))))
|
||||||
|
cers_filename = (
|
||||||
|
params.res_dir / f"cers-{test_set_name}-{key}-{params.suffix}.txt"
|
||||||
|
)
|
||||||
|
with open(cers_filename, "w") as f:
|
||||||
|
cer = write_error_stats(
|
||||||
|
f, f"{test_set_name}-{key}", results_char, enable_log=True
|
||||||
|
)
|
||||||
|
test_set_cers[key] = cer
|
||||||
|
|
||||||
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
|
logging.info("Wrote detailed error stats to {}".format(wers_filename))
|
||||||
|
|
||||||
|
test_set_wers = {
|
||||||
|
k: v for k, v in sorted(test_set_wers.items(), key=lambda x: x[1])
|
||||||
|
}
|
||||||
|
test_set_cers = {
|
||||||
|
k: v for k, v in sorted(test_set_cers.items(), key=lambda x: x[1])
|
||||||
|
}
|
||||||
errs_info = (
|
errs_info = (
|
||||||
params.res_dir
|
params.res_dir
|
||||||
/ f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
|
/ f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
|
||||||
)
|
)
|
||||||
with open(errs_info, "w") as f:
|
with open(errs_info, "w") as f:
|
||||||
print("settings\tWER", file=f)
|
print("settings\tWER\tCER", file=f)
|
||||||
for key, val in test_set_wers:
|
for key in test_set_wers:
|
||||||
print("{}\t{}".format(key, val), file=f)
|
print(
|
||||||
|
"{}\t{}\t{}".format(
|
||||||
|
key, test_set_wers[key], test_set_cers[key]
|
||||||
|
),
|
||||||
|
file=f,
|
||||||
|
)
|
||||||
|
|
||||||
s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
|
s = "\nFor {}, WER/CER of different settings are:\n".format(test_set_name)
|
||||||
note = "\tbest for {}".format(test_set_name)
|
note = "\tbest for {}".format(test_set_name)
|
||||||
for key, val in test_set_wers:
|
for key in test_set_wers:
|
||||||
s += "{}\t{}{}\n".format(key, val, note)
|
s += "{}\t{}\t{}{}\n".format(
|
||||||
|
key, test_set_wers[key], test_set_cers[key], note
|
||||||
|
)
|
||||||
note = ""
|
note = ""
|
||||||
logging.info(s)
|
logging.info(s)
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user