mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-08 09:32:20 +00:00
split save_results()
-> save_asr_output()
+ save_wer_results()
(#1712)
- The idea is to support a `--skip-scoring` argument passed to a decoding script.
- Created for Transducer decoding (non-streaming and streaming).
- It can also be done for CTC decoding (not yet).
- Also added `--label` for an extra label in `streaming_decode.py`.
- Also added `set_caching_enabled(True)`, which has no effect on librispeech, but leads to faster runtime on DBs with long recordings (assuming the `librispeech/zipformer` scripts are the example scripts for other setups).
This commit is contained in:
parent
3b257dd5ae
commit
1730fce688
@ -120,6 +120,7 @@ import sentencepiece as spm
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from asr_datamodule import LibriSpeechAsrDataModule
|
||||
from lhotse import set_caching_enabled
|
||||
from train import add_model_arguments, get_model, get_params
|
||||
|
||||
from icefall.checkpoint import (
|
||||
@ -296,6 +297,13 @@ def get_parser():
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--skip-scoring",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="""Skip scoring, but still save the ASR output (for eval sets)."""
|
||||
)
|
||||
|
||||
add_model_arguments(parser)
|
||||
|
||||
return parser
|
||||
@ -455,7 +463,7 @@ def decode_one_batch(
|
||||
# hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... ]
|
||||
hyps = [s.split() for s in hyps]
|
||||
key = "ctc-decoding"
|
||||
return {key: hyps}
|
||||
return {key: hyps} # note: returns words
|
||||
|
||||
if params.decoding_method == "attention-decoder-rescoring-no-ngram":
|
||||
best_path_dict = rescore_with_attention_decoder_no_ngram(
|
||||
@ -492,7 +500,7 @@ def decode_one_batch(
|
||||
)
|
||||
hyps = get_texts(best_path)
|
||||
hyps = [[word_table[i] for i in ids] for ids in hyps]
|
||||
key = f"oracle_{params.num_paths}_nbest_scale_{params.nbest_scale}" # noqa
|
||||
key = f"oracle_{params.num_paths}_nbest-scale-{params.nbest_scale}" # noqa
|
||||
return {key: hyps}
|
||||
|
||||
if params.decoding_method in ["1best", "nbest"]:
|
||||
@ -500,7 +508,7 @@ def decode_one_batch(
|
||||
best_path = one_best_decoding(
|
||||
lattice=lattice, use_double_scores=params.use_double_scores
|
||||
)
|
||||
key = "no_rescore"
|
||||
key = "no-rescore"
|
||||
else:
|
||||
best_path = nbest_decoding(
|
||||
lattice=lattice,
|
||||
@ -508,11 +516,11 @@ def decode_one_batch(
|
||||
use_double_scores=params.use_double_scores,
|
||||
nbest_scale=params.nbest_scale,
|
||||
)
|
||||
key = f"no_rescore-nbest-scale-{params.nbest_scale}-{params.num_paths}" # noqa
|
||||
key = f"no-rescore_nbest-scale-{params.nbest_scale}-{params.num_paths}" # noqa
|
||||
|
||||
hyps = get_texts(best_path)
|
||||
hyps = [[word_table[i] for i in ids] for ids in hyps]
|
||||
return {key: hyps}
|
||||
return {key: hyps} # note: returns BPE tokens
|
||||
|
||||
assert params.decoding_method in [
|
||||
"nbest-rescoring",
|
||||
@ -646,7 +654,27 @@ def decode_dataset(
|
||||
return results
|
||||
|
||||
|
||||
def save_results(
|
||||
def save_asr_output(
|
||||
params: AttributeDict,
|
||||
test_set_name: str,
|
||||
results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
|
||||
):
|
||||
"""
|
||||
Save text produced by ASR.
|
||||
"""
|
||||
for key, results in results_dict.items():
|
||||
|
||||
recogs_filename = (
|
||||
params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
|
||||
)
|
||||
|
||||
results = sorted(results)
|
||||
store_transcripts(filename=recogs_filename, texts=results)
|
||||
|
||||
logging.info(f"The transcripts are stored in {recogs_filename}")
|
||||
|
||||
|
||||
def save_wer_results(
|
||||
params: AttributeDict,
|
||||
test_set_name: str,
|
||||
results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
|
||||
@ -661,32 +689,30 @@ def save_results(
|
||||
|
||||
test_set_wers = dict()
|
||||
for key, results in results_dict.items():
|
||||
recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
|
||||
results = sorted(results)
|
||||
store_transcripts(filename=recog_path, texts=results)
|
||||
if enable_log:
|
||||
logging.info(f"The transcripts are stored in {recog_path}")
|
||||
|
||||
# The following prints out WERs, per-word error statistics and aligned
|
||||
# ref/hyp pairs.
|
||||
errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
|
||||
with open(errs_filename, "w") as f:
|
||||
wer = write_error_stats(f, f"{test_set_name}-{key}", results)
|
||||
with open(errs_filename, "w", encoding="utf8") as fd:
|
||||
wer = write_error_stats(
|
||||
fd, f"{test_set_name}_{key}", results, enable_log=enable_log
|
||||
)
|
||||
test_set_wers[key] = wer
|
||||
if enable_log:
|
||||
logging.info("Wrote detailed error stats to {}".format(errs_filename))
|
||||
|
||||
logging.info(f"Wrote detailed error stats to {errs_filename}")
|
||||
|
||||
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
|
||||
errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
|
||||
with open(errs_info, "w") as f:
|
||||
print("settings\tWER", file=f)
|
||||
for key, val in test_set_wers:
|
||||
print("{}\t{}".format(key, val), file=f)
|
||||
|
||||
s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
|
||||
note = "\tbest for {}".format(test_set_name)
|
||||
wer_filename = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
|
||||
|
||||
with open(wer_filename, "w", encoding="utf8") as fd:
|
||||
print("settings\tWER", file=fd)
|
||||
for key, val in test_set_wers:
|
||||
print(f"{key}\t{val}", file=fd)
|
||||
|
||||
s = f"\nFor {test_set_name}, WER of different settings are:\n"
|
||||
note = f"\tbest for {test_set_name}"
|
||||
for key, val in test_set_wers:
|
||||
s += "{}\t{}{}\n".format(key, val, note)
|
||||
s += f"{key}\t{val}{note}\n"
|
||||
note = ""
|
||||
logging.info(s)
|
||||
|
||||
@ -705,6 +731,9 @@ def main():
|
||||
params.update(get_decoding_params())
|
||||
params.update(vars(args))
|
||||
|
||||
# enable AudioCache
|
||||
set_caching_enabled(True) # lhotse
|
||||
|
||||
assert params.decoding_method in (
|
||||
"ctc-greedy-search",
|
||||
"ctc-decoding",
|
||||
@ -719,9 +748,9 @@ def main():
|
||||
params.res_dir = params.exp_dir / params.decoding_method
|
||||
|
||||
if params.iter > 0:
|
||||
params.suffix = f"iter-{params.iter}-avg-{params.avg}"
|
||||
params.suffix = f"iter-{params.iter}_avg-{params.avg}"
|
||||
else:
|
||||
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
|
||||
params.suffix = f"epoch-{params.epoch}_avg-{params.avg}"
|
||||
|
||||
if params.causal:
|
||||
assert (
|
||||
@ -730,11 +759,11 @@ def main():
|
||||
assert (
|
||||
"," not in params.left_context_frames
|
||||
), "left_context_frames should be one value in decoding."
|
||||
params.suffix += f"-chunk-{params.chunk_size}"
|
||||
params.suffix += f"-left-context-{params.left_context_frames}"
|
||||
params.suffix += f"_chunk-{params.chunk_size}"
|
||||
params.suffix += f"_left-context-{params.left_context_frames}"
|
||||
|
||||
if params.use_averaged_model:
|
||||
params.suffix += "-use-averaged-model"
|
||||
params.suffix += "_use-averaged-model"
|
||||
|
||||
setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
|
||||
logging.info("Decoding started")
|
||||
@ -940,12 +969,19 @@ def main():
|
||||
G=G,
|
||||
)
|
||||
|
||||
save_results(
|
||||
save_asr_output(
|
||||
params=params,
|
||||
test_set_name=test_set,
|
||||
results_dict=results_dict,
|
||||
)
|
||||
|
||||
if not params.skip_scoring:
|
||||
save_wer_results(
|
||||
params=params,
|
||||
test_set_name=test_set,
|
||||
results_dict=results_dict,
|
||||
)
|
||||
|
||||
logging.info("Done!")
|
||||
|
||||
|
||||
|
@ -121,6 +121,7 @@ from beam_search import (
|
||||
modified_beam_search_lm_shallow_fusion,
|
||||
modified_beam_search_LODR,
|
||||
)
|
||||
from lhotse import set_caching_enabled
|
||||
from train import add_model_arguments, get_model, get_params
|
||||
|
||||
from icefall import ContextGraph, LmScorer, NgramLm
|
||||
@ -369,6 +370,14 @@ def get_parser():
|
||||
modified_beam_search_LODR.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--skip-scoring",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="""Skip scoring, but still save the ASR output (for eval sets).""",
|
||||
)
|
||||
|
||||
add_model_arguments(parser)
|
||||
|
||||
return parser
|
||||
@ -590,21 +599,23 @@ def decode_one_batch(
|
||||
)
|
||||
hyps.append(sp.decode(hyp).split())
|
||||
|
||||
# prefix = ( "greedy_search" | "fast_beam_search_nbest" | "modified_beam_search" )
|
||||
prefix = f"{params.decoding_method}"
|
||||
if params.decoding_method == "greedy_search":
|
||||
return {"greedy_search": hyps}
|
||||
elif "fast_beam_search" in params.decoding_method:
|
||||
key = f"beam_{params.beam}_"
|
||||
key += f"max_contexts_{params.max_contexts}_"
|
||||
key += f"max_states_{params.max_states}"
|
||||
prefix += f"_beam-{params.beam}"
|
||||
prefix += f"_max-contexts-{params.max_contexts}"
|
||||
prefix += f"_max-states-{params.max_states}"
|
||||
if "nbest" in params.decoding_method:
|
||||
key += f"_num_paths_{params.num_paths}_"
|
||||
key += f"nbest_scale_{params.nbest_scale}"
|
||||
prefix += f"_num-paths-{params.num_paths}"
|
||||
prefix += f"_nbest-scale-{params.nbest_scale}"
|
||||
if "LG" in params.decoding_method:
|
||||
key += f"_ngram_lm_scale_{params.ngram_lm_scale}"
|
||||
prefix += f"_ngram-lm-scale-{params.ngram_lm_scale}"
|
||||
|
||||
return {key: hyps}
|
||||
return {prefix: hyps}
|
||||
elif "modified_beam_search" in params.decoding_method:
|
||||
prefix = f"beam_size_{params.beam_size}"
|
||||
prefix += f"_beam-size-{params.beam_size}"
|
||||
if params.decoding_method in (
|
||||
"modified_beam_search_lm_rescore",
|
||||
"modified_beam_search_lm_rescore_LODR",
|
||||
@ -617,10 +628,11 @@ def decode_one_batch(
|
||||
return ans
|
||||
else:
|
||||
if params.has_contexts:
|
||||
prefix += f"-context-score-{params.context_score}"
|
||||
prefix += f"_context-score-{params.context_score}"
|
||||
return {prefix: hyps}
|
||||
else:
|
||||
return {f"beam_size_{params.beam_size}": hyps}
|
||||
prefix += f"_beam-size-{params.beam_size}"
|
||||
return {prefix: hyps}
|
||||
|
||||
|
||||
def decode_dataset(
|
||||
@ -707,46 +719,58 @@ def decode_dataset(
|
||||
return results
|
||||
|
||||
|
||||
def save_results(
|
||||
def save_asr_output(
|
||||
params: AttributeDict,
|
||||
test_set_name: str,
|
||||
results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
|
||||
):
|
||||
"""
|
||||
Save text produced by ASR.
|
||||
"""
|
||||
for key, results in results_dict.items():
|
||||
|
||||
recogs_filename = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
|
||||
|
||||
results = sorted(results)
|
||||
store_transcripts(filename=recogs_filename, texts=results)
|
||||
|
||||
logging.info(f"The transcripts are stored in {recogs_filename}")
|
||||
|
||||
|
||||
def save_wer_results(
|
||||
params: AttributeDict,
|
||||
test_set_name: str,
|
||||
results_dict: Dict[str, List[Tuple[str, List[str], List[str], Tuple]]],
|
||||
):
|
||||
"""
|
||||
Save WER and per-utterance word alignments.
|
||||
"""
|
||||
test_set_wers = dict()
|
||||
for key, results in results_dict.items():
|
||||
recog_path = (
|
||||
params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
|
||||
)
|
||||
results = sorted(results)
|
||||
store_transcripts(filename=recog_path, texts=results)
|
||||
logging.info(f"The transcripts are stored in {recog_path}")
|
||||
|
||||
# The following prints out WERs, per-word error statistics and aligned
|
||||
# ref/hyp pairs.
|
||||
errs_filename = (
|
||||
params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
|
||||
)
|
||||
with open(errs_filename, "w") as f:
|
||||
errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
|
||||
with open(errs_filename, "w", encoding="utf8") as fd:
|
||||
wer = write_error_stats(
|
||||
f, f"{test_set_name}-{key}", results, enable_log=True
|
||||
fd, f"{test_set_name}-{key}", results, enable_log=True
|
||||
)
|
||||
test_set_wers[key] = wer
|
||||
|
||||
logging.info("Wrote detailed error stats to {}".format(errs_filename))
|
||||
logging.info(f"Wrote detailed error stats to {errs_filename}")
|
||||
|
||||
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
|
||||
errs_info = (
|
||||
params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
|
||||
)
|
||||
with open(errs_info, "w") as f:
|
||||
print("settings\tWER", file=f)
|
||||
for key, val in test_set_wers:
|
||||
print("{}\t{}".format(key, val), file=f)
|
||||
|
||||
s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
|
||||
note = "\tbest for {}".format(test_set_name)
|
||||
wer_filename = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
|
||||
|
||||
with open(wer_filename, "w", encoding="utf8") as fd:
|
||||
print("settings\tWER", file=fd)
|
||||
for key, val in test_set_wers:
|
||||
print(f"{key}\t{val}", file=fd)
|
||||
|
||||
s = f"\nFor {test_set_name}, WER of different settings are:\n"
|
||||
note = f"\tbest for {test_set_name}"
|
||||
for key, val in test_set_wers:
|
||||
s += "{}\t{}{}\n".format(key, val, note)
|
||||
s += f"{key}\t{val}{note}\n"
|
||||
note = ""
|
||||
logging.info(s)
|
||||
|
||||
@ -762,6 +786,9 @@ def main():
|
||||
params = get_params()
|
||||
params.update(vars(args))
|
||||
|
||||
# enable AudioCache
|
||||
set_caching_enabled(True) # lhotse
|
||||
|
||||
assert params.decoding_method in (
|
||||
"greedy_search",
|
||||
"beam_search",
|
||||
@ -783,9 +810,9 @@ def main():
|
||||
params.has_contexts = False
|
||||
|
||||
if params.iter > 0:
|
||||
params.suffix = f"iter-{params.iter}-avg-{params.avg}"
|
||||
params.suffix = f"iter-{params.iter}_avg-{params.avg}"
|
||||
else:
|
||||
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
|
||||
params.suffix = f"epoch-{params.epoch}_avg-{params.avg}"
|
||||
|
||||
if params.causal:
|
||||
assert (
|
||||
@ -794,20 +821,20 @@ def main():
|
||||
assert (
|
||||
"," not in params.left_context_frames
|
||||
), "left_context_frames should be one value in decoding."
|
||||
params.suffix += f"-chunk-{params.chunk_size}"
|
||||
params.suffix += f"-left-context-{params.left_context_frames}"
|
||||
params.suffix += f"_chunk-{params.chunk_size}"
|
||||
params.suffix += f"_left-context-{params.left_context_frames}"
|
||||
|
||||
if "fast_beam_search" in params.decoding_method:
|
||||
params.suffix += f"-beam-{params.beam}"
|
||||
params.suffix += f"-max-contexts-{params.max_contexts}"
|
||||
params.suffix += f"-max-states-{params.max_states}"
|
||||
params.suffix += f"_beam-{params.beam}"
|
||||
params.suffix += f"_max-contexts-{params.max_contexts}"
|
||||
params.suffix += f"_max-states-{params.max_states}"
|
||||
if "nbest" in params.decoding_method:
|
||||
params.suffix += f"-nbest-scale-{params.nbest_scale}"
|
||||
params.suffix += f"-num-paths-{params.num_paths}"
|
||||
params.suffix += f"_nbest-scale-{params.nbest_scale}"
|
||||
params.suffix += f"_num-paths-{params.num_paths}"
|
||||
if "LG" in params.decoding_method:
|
||||
params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
|
||||
params.suffix += f"_ngram-lm-scale-{params.ngram_lm_scale}"
|
||||
elif "beam_search" in params.decoding_method:
|
||||
params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
|
||||
params.suffix += f"__{params.decoding_method}__beam-size-{params.beam_size}"
|
||||
if params.decoding_method in (
|
||||
"modified_beam_search",
|
||||
"modified_beam_search_LODR",
|
||||
@ -815,19 +842,19 @@ def main():
|
||||
if params.has_contexts:
|
||||
params.suffix += f"-context-score-{params.context_score}"
|
||||
else:
|
||||
params.suffix += f"-context-{params.context_size}"
|
||||
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
|
||||
params.suffix += f"_context-{params.context_size}"
|
||||
params.suffix += f"_max-sym-per-frame-{params.max_sym_per_frame}"
|
||||
|
||||
if params.use_shallow_fusion:
|
||||
params.suffix += f"-{params.lm_type}-lm-scale-{params.lm_scale}"
|
||||
params.suffix += f"_{params.lm_type}-lm-scale-{params.lm_scale}"
|
||||
|
||||
if "LODR" in params.decoding_method:
|
||||
params.suffix += (
|
||||
f"-LODR-{params.tokens_ngram}gram-scale-{params.ngram_lm_scale}"
|
||||
f"_LODR-{params.tokens_ngram}gram-scale-{params.ngram_lm_scale}"
|
||||
)
|
||||
|
||||
if params.use_averaged_model:
|
||||
params.suffix += "-use-averaged-model"
|
||||
params.suffix += "_use-averaged-model"
|
||||
|
||||
setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
|
||||
logging.info("Decoding started")
|
||||
@ -1038,12 +1065,19 @@ def main():
|
||||
ngram_lm_scale=ngram_lm_scale,
|
||||
)
|
||||
|
||||
save_results(
|
||||
save_asr_output(
|
||||
params=params,
|
||||
test_set_name=test_set,
|
||||
results_dict=results_dict,
|
||||
)
|
||||
|
||||
if not params.skip_scoring:
|
||||
save_wer_results(
|
||||
params=params,
|
||||
test_set_name=test_set,
|
||||
results_dict=results_dict,
|
||||
)
|
||||
|
||||
logging.info("Done!")
|
||||
|
||||
|
||||
|
@ -43,7 +43,7 @@ import torch
|
||||
from asr_datamodule import LibriSpeechAsrDataModule
|
||||
from decode_stream import DecodeStream
|
||||
from kaldifeat import Fbank, FbankOptions
|
||||
from lhotse import CutSet
|
||||
from lhotse import CutSet, set_caching_enabled
|
||||
from streaming_beam_search import (
|
||||
fast_beam_search_one_best,
|
||||
greedy_search,
|
||||
@ -76,6 +76,13 @@ def get_parser():
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--label",
|
||||
type=str,
|
||||
default="",
|
||||
help="""Extra label of the decoding run.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--epoch",
|
||||
type=int,
|
||||
@ -188,6 +195,14 @@ def get_parser():
|
||||
help="The number of streams that can be decoded parallel.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--skip-scoring",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="""Skip scoring, but still save the ASR output (for eval sets)."""
|
||||
)
|
||||
|
||||
|
||||
add_model_arguments(parser)
|
||||
|
||||
return parser
|
||||
@ -640,46 +655,60 @@ def decode_dataset(
|
||||
return {key: decode_results}
|
||||
|
||||
|
||||
def save_results(
|
||||
def save_asr_output(
|
||||
params: AttributeDict,
|
||||
test_set_name: str,
|
||||
results_dict: Dict[str, List[Tuple[List[str], List[str]]]],
|
||||
):
|
||||
test_set_wers = dict()
|
||||
"""
|
||||
Save text produced by ASR.
|
||||
"""
|
||||
for key, results in results_dict.items():
|
||||
recog_path = (
|
||||
recogs_filename = (
|
||||
params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
|
||||
)
|
||||
results = sorted(results)
|
||||
store_transcripts(filename=recog_path, texts=results)
|
||||
logging.info(f"The transcripts are stored in {recog_path}")
|
||||
store_transcripts(filename=recogs_filename, texts=results)
|
||||
logging.info(f"The transcripts are stored in {recogs_filename}")
|
||||
|
||||
def save_wer_results(
|
||||
params: AttributeDict,
|
||||
test_set_name: str,
|
||||
results_dict: Dict[str, List[Tuple[List[str], List[str]]]],
|
||||
):
|
||||
"""
|
||||
Save WER and per-utterance word alignments.
|
||||
"""
|
||||
test_set_wers = dict()
|
||||
for key, results in results_dict.items():
|
||||
|
||||
# The following prints out WERs, per-word error statistics and aligned
|
||||
# ref/hyp pairs.
|
||||
errs_filename = (
|
||||
params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
|
||||
)
|
||||
with open(errs_filename, "w") as f:
|
||||
with open(errs_filename, "w", encoding="utf8") as fd:
|
||||
wer = write_error_stats(
|
||||
f, f"{test_set_name}-{key}", results, enable_log=True
|
||||
fd, f"{test_set_name}-{key}", results, enable_log=True
|
||||
)
|
||||
test_set_wers[key] = wer
|
||||
|
||||
logging.info("Wrote detailed error stats to {}".format(errs_filename))
|
||||
logging.info(f"Wrote detailed error stats to {errs_filename}")
|
||||
|
||||
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
|
||||
errs_info = (
|
||||
|
||||
wer_filename = (
|
||||
params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
|
||||
)
|
||||
with open(errs_info, "w") as f:
|
||||
print("settings\tWER", file=f)
|
||||
with open(wer_filename, "w", encoding="utf8") as fd:
|
||||
print("settings\tWER", file=fd)
|
||||
for key, val in test_set_wers:
|
||||
print("{}\t{}".format(key, val), file=f)
|
||||
print(f"{key}\t{val}", file=fd)
|
||||
|
||||
s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
|
||||
note = "\tbest for {}".format(test_set_name)
|
||||
s = f"\nFor {test_set_name}, WER of different settings are:\n"
|
||||
note = f"\tbest for {test_set_name}"
|
||||
for key, val in test_set_wers:
|
||||
s += "{}\t{}{}\n".format(key, val, note)
|
||||
s += f"{key}\t{val}{note}\n"
|
||||
note = ""
|
||||
logging.info(s)
|
||||
|
||||
@ -694,6 +723,9 @@ def main():
|
||||
params = get_params()
|
||||
params.update(vars(args))
|
||||
|
||||
# enable AudioCache
|
||||
set_caching_enabled(True) # lhotse
|
||||
|
||||
params.res_dir = params.exp_dir / "streaming" / params.decoding_method
|
||||
|
||||
if params.iter > 0:
|
||||
@ -706,18 +738,21 @@ def main():
|
||||
assert (
|
||||
"," not in params.left_context_frames
|
||||
), "left_context_frames should be one value in decoding."
|
||||
params.suffix += f"-chunk-{params.chunk_size}"
|
||||
params.suffix += f"-left-context-{params.left_context_frames}"
|
||||
params.suffix += f"_chunk-{params.chunk_size}"
|
||||
params.suffix += f"_left-context-{params.left_context_frames}"
|
||||
|
||||
# for fast_beam_search
|
||||
if params.decoding_method == "fast_beam_search":
|
||||
params.suffix += f"-beam-{params.beam}"
|
||||
params.suffix += f"-max-contexts-{params.max_contexts}"
|
||||
params.suffix += f"-max-states-{params.max_states}"
|
||||
params.suffix += f"_beam-{params.beam}"
|
||||
params.suffix += f"_max-contexts-{params.max_contexts}"
|
||||
params.suffix += f"_max-states-{params.max_states}"
|
||||
|
||||
if params.use_averaged_model:
|
||||
params.suffix += "-use-averaged-model"
|
||||
|
||||
if params.label:
|
||||
params.suffix += f"-{params.label}"
|
||||
|
||||
setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
|
||||
logging.info("Decoding started")
|
||||
|
||||
@ -845,12 +880,21 @@ def main():
|
||||
decoding_graph=decoding_graph,
|
||||
)
|
||||
|
||||
save_results(
|
||||
|
||||
save_asr_output(
|
||||
params=params,
|
||||
test_set_name=test_set,
|
||||
results_dict=results_dict,
|
||||
)
|
||||
|
||||
|
||||
if not params.skip_scoring:
|
||||
save_wer_results(
|
||||
params=params,
|
||||
test_set_name=test_set,
|
||||
results_dict=results_dict,
|
||||
)
|
||||
|
||||
logging.info("Done!")
|
||||
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user