mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-08 09:32:20 +00:00
* update manifest stats * update transcript configs * lang_char and compute_fbanks * save cuts in fbank_dir * add core codes * update decode.py * Create local/utils * tidy up * parse raw in prepare_lang_char.py * update manifest stats * update transcript configs * lang_char and compute_fbanks * save cuts in fbank_dir * add core codes * update decode.py * Create local/utils * tidy up * parse raw in prepare_lang_char.py * working train * Add compare_cer_transcript.py * fix tokenizer decode, allow d2f only * comment cleanup * add export files and READMEs * reword average column * fix comments * Update new results
203 lines
5.7 KiB
Python
203 lines
5.7 KiB
Python
import argparse
import ast
from pathlib import Path

import kaldialign

from lhotse import CutSet
|
|
|
|
ARGPARSE_DESCRIPTION = """
|
|
This helper code takes in a disfluent recogs file generated from icefall.utils.store_transcript,
|
|
compares it against a fluent transcript, and saves the results in a separate directory.
|
|
This is useful to compare disfluent models with fluent models on the same metric.
|
|
|
|
"""
|
|
|
|
|
|
def get_args():
    """Build the command-line parser and return the parsed arguments."""
    parser = argparse.ArgumentParser(
        description=ARGPARSE_DESCRIPTION,
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    add = parser.add_argument

    add(
        "--recogs",
        required=True,
        type=Path,
        help="Path to the recogs-XXX file generated by icefall.utils.store_transcript.",
    )
    add(
        "--cut",
        required=True,
        type=Path,
        help="Path to the cut manifest to be compared to. Assumes that disfluent_tag exists in the custom dict.",
    )
    add("--res-dir", type=Path, required=True, help="Path to save results")

    return parser.parse_args()
|
|
|
|
|
|
def d2f(stats):
    """
    Score a disfluent model's output against a fluent reference.

    Only the content (fluent) words contribute to the metric:

        CER^d_f = (sub_f + ins + del_f) / Nf

    ``stats["base"]`` holds the numerator (fluent subs + all insertions +
    fluent deletions) and ``stats["Nf"]`` the fluent reference length.
    """
    numerator = stats["base"]
    fluent_ref_len = stats["Nf"]
    return numerator / fluent_ref_len
|
|
|
|
|
|
def calc_cer(refs, hyps):
    """
    Accumulate CSJ-tag-aware edit statistics between references and hypotheses.

    Each element of ``refs`` is ``(cut_id, ref_tokens, csj_tags)`` and each
    element of ``hyps`` is ``(cut_id, hyp_tokens)``; the two sequences must be
    aligned pairwise by cut id. A reference token whose CSJ tag contains 'D'
    or 'F' is counted in the disfluent ('D') bucket, all other tokens in the
    fluent ('F') bucket.

    Returns a dict with per-bucket substitution/deletion/correct counts, the
    total insertion count, and the disfluent and fluent reference lengths.
    """
    subs = {"F": 0, "D": 0}
    dels = {"F": 0, "D": 0}
    cors = {"F": 0, "D": 0}
    ins = 0
    dis_ref_len = 0
    flu_ref_len = 0

    for (cut_id, ref_tokens, csj_tags), (hyp_id, hyp_tokens) in zip(refs, hyps):
        assert (
            cut_id == hyp_id
        ), f"Expected ref cut id {cut_id} to be the same as hyp cut id {hyp_id}."
        # Consumed token-by-token below, so work on a copy.
        remaining_tags = csj_tags.copy()
        dis_ref_len += len(ref_tokens)
        # Remember that the 'D' and 'F' tags here refer to CSJ tags, not
        # disfluent and fluent respectively.
        flu_ref_len += sum(
            1 for t in remaining_tags if "D" not in t and "F" not in t
        )

        ali = kaldialign.align(ref_tokens, hyp_tokens, "*")
        # Insertions have no reference token, hence no CSJ tag: mark them "*".
        aligned_tags = [
            "*" if ref_side == "*" else remaining_tags.pop(0)
            for ref_side, _ in ali
        ]

        for raw_tag, (ref_word, hyp_word) in zip(aligned_tags, ali):
            bucket = "D" if ("D" in raw_tag or "F" in raw_tag) else "F"
            if ref_word == "*":
                ins += 1
            elif hyp_word == "*":
                dels[bucket] += 1
            elif ref_word != hyp_word:
                subs[bucket] += 1
            else:
                cors[bucket] += 1

    return {
        "subs": subs,
        "ins": ins,
        "dels": dels,
        "cors": cors,
        "dis_ref_len": dis_ref_len,
        "flu_ref_len": flu_ref_len,
    }
|
|
|
|
|
|
def for_each_recogs(recogs_file: Path, refs, out_dir):
    """
    Score one recogs file against the fluent reference and write a report.

    Parses hypothesis transcripts out of ``recogs_file`` (produced by
    icefall.utils.store_transcript), computes disfluent-vs-fluent CER
    statistics against ``refs``, and writes a ``<stem>.dfcer`` report into
    ``out_dir`` containing the CER^d_f value, Nf, and a per-tag stat table.

    :param recogs_file: recogs-XXX file; hyp lines look like
        ``<cutid>:\thyp=[...]``, ref lines are skipped.
    :param refs: list of ``(cut_id, ref_tokens, csj_tags)`` tuples sorted by
        cut id, matching the order of the hyp lines in ``recogs_file``.
    :param out_dir: existing directory to write the ``.dfcer`` report into.
    """
    hyps = []
    with recogs_file.open() as fin:
        for line in fin:
            # The recogs file interleaves ref= and hyp= lines; only hyp lines
            # are scored (the fluent references come from `refs` instead).
            if "ref" in line:
                continue
            cutid, hyp = line.split(":\thyp=")
            # literal_eval instead of eval: the payload is just the repr of a
            # Python list of tokens, and literal_eval parses it without ever
            # executing code from the (untrusted) recogs file.
            hyps.append((cutid, ast.literal_eval(hyp)))

    assert len(refs) == len(
        hyps
    ), f"Expected refs len {len(refs)} and hyps len {len(hyps)} to be equal."
    stats = calc_cer(refs, hyps)
    stat_table = ["tag,yes,no"]

    for cer_type in ["subs", "dels", "cors", "ins"]:
        ret = f"{cer_type}"
        for df in ["D", "F"]:
            try:
                ret += f",{stats[cer_type][df]}"
            except TypeError:
                # insertions do not belong to F or D, and is not subscriptable.
                ret += f",{stats[cer_type]},"
                break
        stat_table.append(ret)
    stat_table = "\n".join(stat_table)

    # Reduce the raw counts to the inputs d2f needs.
    stats = {
        "subd": stats["subs"]["D"],
        "deld": stats["dels"]["D"],
        "cord": stats["cors"]["D"],
        "Nf": stats["flu_ref_len"],
        # Numerator of CER^d_f: fluent subs + all insertions + fluent dels.
        "base": stats["subs"]["F"] + stats["ins"] + stats["dels"]["F"],
    }

    cer = d2f(stats)
    results = [
        f"{cer:.2%}",
        f"Nf,{stats['Nf']}",
    ]
    results = "\n".join(results)

    with (out_dir / (recogs_file.stem + ".dfcer")).open("w") as fout:
        fout.write(results)
        fout.write("\n\n")
        fout.write(stat_table)
|
|
|
|
|
|
def main():
    """Entry point: score one recogs file, or every recogs file in a directory."""
    args = get_args()
    recogs_file: Path = args.recogs
    assert (
        recogs_file.is_file() or recogs_file.is_dir()
    ), f"recogs_file cannot be found at {recogs_file}."

    args.res_dir.mkdir(parents=True, exist_ok=True)

    def load_refs(cut_path: Path):
        # Build (cut_id, disfluent characters, CSJ tag list) tuples from a
        # cut manifest, sorted by cut id to match the recogs ordering.
        cuts: CutSet = CutSet.from_file(cut_path)
        entries = [
            (
                cut.id,
                list(cut.supervisions[0].custom["disfluent"]),
                cut.supervisions[0].custom["disfluent_tag"].split(","),
            )
            for cut in cuts
        ]
        entries.sort(key=lambda entry: entry[0])
        return entries

    if recogs_file.is_file() and recogs_file.stem.startswith("recogs-"):
        # Single recogs file: --cut must point directly at its cuts manifest.
        assert (
            "csj_cuts" in args.cut.name
        ), f"Expected {args.cut} to be a cuts manifest."
        for_each_recogs(recogs_file, load_refs(args.cut), args.res_dir)

    elif recogs_file.is_dir():
        # Directory of recogs files: --cut is the fbank dir holding one
        # manifest per evaluation partition.
        for partname in ["eval1", "eval2", "eval3", "excluded", "valid"]:
            refs = load_refs(args.cut / f"csj_cuts_{partname}.jsonl.gz")
            for one_recogs in recogs_file.glob(f"recogs-{partname}-*.txt"):
                for_each_recogs(one_recogs, refs, args.res_dir)

    else:
        raise TypeError(f"Unrecognised recogs file provided: {recogs_file}")
|
|
|
|
|
|
# Run only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|