diff --git a/egs/spgispeech/ASR/README.md b/egs/spgispeech/ASR/README.md index 67f21bba8..f60408cc1 100644 --- a/egs/spgispeech/ASR/README.md +++ b/egs/spgispeech/ASR/README.md @@ -23,11 +23,10 @@ ArXiv link: https://arxiv.org/abs/2104.02014 ## Performance Record -| Decoding method | val | -|---------------------------|------------| -| greedy search | 2.40 | -| beam search | 2.24 | -| modified beam search | 2.24 | -| fast beam search | 2.35 | +| Decoding method | val WER | val CER | +|---------------------------|------------|---------| +| greedy search | 2.40 | 0.99 | +| modified beam search | 2.24 | 0.91 | +| fast beam search | 2.35 | 0.97 | See [RESULTS](/egs/spgispeech/ASR/RESULTS.md) for details. diff --git a/egs/spgispeech/ASR/RESULTS.md b/egs/spgispeech/ASR/RESULTS.md index c63b8ce90..de9e35c5a 100644 --- a/egs/spgispeech/ASR/RESULTS.md +++ b/egs/spgispeech/ASR/RESULTS.md @@ -15,7 +15,6 @@ The WERs are | | dev | val | comment | |---------------------------|------------|------------|------------------------------------------| | greedy search | 2.46 | 2.40 | --avg-last-n 10 --max-duration 500 | -| beam search | 2.27 | 2.24 | --avg-last-n 10 --max-duration 500 --beam-size 4 | | modified beam search | 2.28 | 2.24 | --avg-last-n 10 --max-duration 500 --beam-size 4 | | fast beam search | 2.38 | 2.35 | --avg-last-n 10 --max-duration 500 --beam-size 4 --max-contexts 4 --max-states 8 | @@ -49,14 +48,6 @@ The decoding command is: --max-duration 100 \ --decoding-method greedy_search -# beam search -./pruned_transducer_stateless2/decode.py \ - --iter 696000 --avg 10 \ - --exp-dir ./pruned_transducer_stateless2/exp \ - --max-duration 100 \ - --decoding-method beam_search \ - --beam-size 4 - # modified beam search ./pruned_transducer_stateless2/decode.py \ --iter 696000 --avg 10 \ diff --git a/egs/spgispeech/ASR/pruned_transducer_stateless2/decode.py b/egs/spgispeech/ASR/pruned_transducer_stateless2/decode.py index 86626f058..ae49d166b 100755 --- a/egs/spgispeech/ASR/pruned_transducer_stateless2/decode.py +++ b/egs/spgispeech/ASR/pruned_transducer_stateless2/decode.py @@ -69,7 +69,7 @@ import torch.nn as nn from asr_datamodule import SPGISpeechAsrDataModule from beam_search import ( beam_search, - fast_beam_search, + fast_beam_search_one_best, greedy_search, greedy_search_batch, modified_beam_search, @@ -252,7 +252,7 @@ def decode_one_batch( hyps = [] if params.decoding_method == "fast_beam_search": - hyp_tokens = fast_beam_search( + hyp_tokens = fast_beam_search_one_best( model=model, decoding_graph=decoding_graph, encoder_out=encoder_out, @@ -400,6 +400,7 @@ def save_results( results_dict: Dict[str, List[Tuple[List[int], List[int]]]], ): test_set_wers = dict() + test_set_cers = dict() for key, results in results_dict.items(): recog_path = ( params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" @@ -409,31 +410,56 @@ def save_results( # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. - errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" + wers_filename = ( + params.res_dir / f"wers-{test_set_name}-{key}-{params.suffix}.txt" ) - with open(errs_filename, "w") as f: + with open(wers_filename, "w") as f: wer = write_error_stats( f, f"{test_set_name}-{key}", results, enable_log=True ) test_set_wers[key] = wer - logging.info("Wrote detailed error stats to {}".format(errs_filename)) + # we also compute CER for spgispeech dataset. + results_char = [] + for res in results: + results_char.append((list("".join(res[0])), list("".join(res[1])))) + cers_filename = ( + params.res_dir / f"cers-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(cers_filename, "w") as f: + cer = write_error_stats( + f, f"{test_set_name}-{key}", results_char, enable_log=True + ) + test_set_cers[key] = cer - test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) + logging.info("Wrote detailed error stats to {}".format(wers_filename)) + + test_set_wers = { + k: v for k, v in sorted(test_set_wers.items(), key=lambda x: x[1]) + } + test_set_cers = { + k: v for k, v in sorted(test_set_cers.items(), key=lambda x: x[1]) + } errs_info = ( params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: - print("settings\tWER", file=f) - for key, val in test_set_wers: - print("{}\t{}".format(key, val), file=f) + print("settings\tWER\tCER", file=f) + for key in test_set_wers: + print( + "{}\t{}\t{}".format( + key, test_set_wers[key], test_set_cers[key] + ), + file=f, + ) - s = "\nFor {}, WER of different settings are:\n".format(test_set_name) + s = "\nFor {}, WER/CER of different settings are:\n".format(test_set_name) note = "\tbest for {}".format(test_set_name) - for key, val in test_set_wers: - s += "{}\t{}{}\n".format(key, val, note) + for key in test_set_wers: + s += "{}\t{}\t{}{}\n".format( + key, test_set_wers[key], test_set_cers[key], note + ) note = "" logging.info(s)