icefall/egs/csj/ASR/local/disfluent_recogs_to_fluent.py
Teo Wen Shen e63a8c27f8
CSJ pruned_transducer_stateless7_streaming (#892)
* update manifest stats

* update transcript configs

* lang_char and compute_fbanks

* save cuts in fbank_dir

* add core codes

* update decode.py

* Create local/utils

* tidy up

* parse raw in prepare_lang_char.py

* update manifest stats

* update transcript configs

* lang_char and compute_fbanks

* save cuts in fbank_dir

* add core codes

* update decode.py

* Create local/utils

* tidy up

* parse raw in prepare_lang_char.py

* working train

* Add compare_cer_transcript.py

* fix tokenizer decode, allow d2f only

* comment cleanup

* add export files and READMEs

* reword average column

* fix comments

* Update new results
2023-02-13 22:19:50 +08:00

203 lines
5.7 KiB
Python

import argparse
import ast
from pathlib import Path

import kaldialign
from lhotse import CutSet
# Help text for the command-line interface; passed as the argparse
# `description` in get_args().
ARGPARSE_DESCRIPTION = """
This helper code takes in a disfluent recogs file generated from icefall.utils.store_transcript,
compares it against a fluent transcript, and saves the results in a separate directory.
This is useful to compare disfluent models with fluent models on the same metric.
"""
def get_args():
    """
    Build the command-line parser and return the parsed arguments.

    Every option is required and is converted to a ``pathlib.Path``.
    """
    path_options = (
        (
            "--recogs",
            "Path to the recogs-XXX file generated by icefall.utils.store_transcript.",
        ),
        (
            "--cut",
            "Path to the cut manifest to be compared to. Assumes that disfluent_tag exists in the custom dict.",
        ),
        ("--res-dir", "Path to save results"),
    )
    parser = argparse.ArgumentParser(
        description=ARGPARSE_DESCRIPTION,
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    for flag, help_text in path_options:
        parser.add_argument(flag, type=Path, required=True, help=help_text)
    return parser.parse_args()
def d2f(stats):
    """
    Return the disfluent-to-fluent character error rate.

    Scores a disfluent model's output against a fluent reference, so the
    metric reflects performance on content characters only:

        CER^d_f = (sub_f + ins + del_f) / Nf

    Args:
      stats: mapping holding ``"base"`` (the numerator above) and ``"Nf"``
        (the fluent reference length, the denominator).

    Returns:
      The ratio as a float. Raises ZeroDivisionError when ``Nf`` is 0.
    """
    numerator = stats["base"]
    denominator = stats["Nf"]
    return numerator / denominator
def calc_cer(refs, hyps):
    """
    Align each reference with its hypothesis and tally edit operations,
    split by whether the reference character carries a CSJ D/F tag.

    Args:
      refs: sequence of ``(cut_id, ref_chars, ref_tags)`` tuples, where
        ``ref_tags`` holds one tag string per character of ``ref_chars``.
      hyps: sequence of ``(cut_id, hyp_chars)`` tuples in the same order
        as ``refs``.

    Returns:
      A dict with:
        * ``subs`` / ``dels`` / ``cors``: counts keyed by ``"D"`` (reference
          character whose tag contains "D" or "F") and ``"F"`` (all other
          reference characters). Note these keys are bucket names, not the
          CSJ tags themselves.
        * ``ins``: total insertions (they have no reference tag).
        * ``dis_ref_len``: total number of reference characters.
        * ``flu_ref_len``: number of reference tags without "D"/"F".

    Raises:
      AssertionError: if paired cut ids differ, or a reference has fewer
        tags than characters.
    """
    subs = {"F": 0, "D": 0}
    ins = 0
    dels = {"F": 0, "D": 0}
    cors = {"F": 0, "D": 0}
    dis_ref_len = 0
    flu_ref_len = 0
    for ref, hyp in zip(refs, hyps):
        ref_id, ref_chars, ref_tags = ref[0], ref[1], ref[2]
        hyp_id, hyp_chars = hyp[0], hyp[1]
        assert (
            ref_id == hyp_id
        ), f"Expected ref cut id {ref_id} to be the same as hyp cut id {hyp_id}."
        # Every reference character consumes one tag below; fail clearly
        # instead of with an IndexError deep in the alignment loop.
        assert len(ref_tags) >= len(ref_chars), (
            f"Expected at least one tag per reference character for cut "
            f"{ref_id}, got {len(ref_tags)} tags for {len(ref_chars)} characters."
        )
        dis_ref_len += len(ref_chars)
        # Remember that the 'D' and 'F' tags here refer to CSJ tags, not
        # disfluent and fluent respectively.
        tagged = ["D" in t or "F" in t for t in ref_tags]
        flu_ref_len += tagged.count(False)
        # Tags are consumed in order, one per reference character in the
        # alignment; an iterator avoids the O(n) cost of list.pop(0).
        tagged_iter = iter(tagged)
        for ref_word, hyp_word in kaldialign.align(ref_chars, hyp_chars, "*"):
            if ref_word == "*":
                # No reference character, hence no tag to consume.
                ins += 1
                continue
            bucket = "D" if next(tagged_iter) else "F"
            if hyp_word == "*":
                dels[bucket] += 1
            elif ref_word != hyp_word:
                subs[bucket] += 1
            else:
                cors[bucket] += 1
    return {
        "subs": subs,
        "ins": ins,
        "dels": dels,
        "cors": cors,
        "dis_ref_len": dis_ref_len,
        "flu_ref_len": flu_ref_len,
    }
def for_each_recogs(recogs_file: Path, refs, out_dir):
    """
    Score one recogs file against fluent references and write a report.

    Reads the hypothesis lines of ``recogs_file``, which have the form
    ``<cut_id>:<TAB>hyp=[...]``. The ``ref=`` lines are skipped because the
    references come from ``refs``. The hypotheses are scored with
    ``calc_cer`` and ``d2f``, and ``<recogs stem>.dfcer`` is written into
    ``out_dir``.

    The report contains the d2f CER and ``Nf``, a blank line, then a CSV
    table of counts. In that table, ``yes`` means a reference character
    tagged with CSJ D/F and ``no`` means untagged.

    Args:
      recogs_file: recogs file to score.
      refs: ``(cut_id, ref_chars, ref_tags)`` tuples, in the same order as
        the hypothesis lines of ``recogs_file``.
      out_dir: directory in which the report is written (must exist).

    Raises:
      ValueError: if a non-blank line is neither a ``ref=`` nor a ``hyp=``
        line.
      AssertionError: if the number of hypotheses differs from ``len(refs)``.
    """
    hyps = []
    with recogs_file.open() as fin:
        for line in fin:
            line = line.rstrip("\r\n")
            if not line.strip():
                continue
            cutid, _, field = line.partition(":\t")
            # Match on the field marker, not a bare "ref" substring, which
            # could also occur in a cut id or a hypothesis.
            if field.startswith("ref="):
                continue
            if not field.startswith("hyp="):
                raise ValueError(f"Unexpected line in {recogs_file}: {line!r}")
            # The hypothesis is stored as a Python list repr; literal_eval
            # parses it without executing arbitrary code as eval() would.
            hyps.append((cutid, ast.literal_eval(field[len("hyp=") :])))
    assert len(refs) == len(
        hyps
    ), f"Expected refs len {len(refs)} and hyps len {len(hyps)} to be equal."

    stats = calc_cer(refs, hyps)

    stat_table = ["tag,yes,no"]
    for cer_type in ("subs", "dels", "cors"):
        counts = stats[cer_type]
        stat_table.append(f"{cer_type},{counts['D']},{counts['F']}")
    # Insertions have no reference tag; the trailing comma keeps three columns.
    stat_table.append(f"ins,{stats['ins']},")
    stat_table = "\n".join(stat_table)

    summary = {
        "Nf": stats["flu_ref_len"],
        # CER^d_f numerator: errors on untagged characters plus insertions.
        "base": stats["subs"]["F"] + stats["ins"] + stats["dels"]["F"],
    }
    cer = d2f(summary)
    results = "\n".join([f"{cer:.2%}", f"Nf,{summary['Nf']}"])
    with (out_dir / (recogs_file.stem + ".dfcer")).open("w") as fout:
        fout.write(results)
        fout.write("\n\n")
        fout.write(stat_table)
def _load_refs(cut_path: Path):
    """Return ``(cut_id, disfluent_chars, disfluent_tags)`` tuples from a cut manifest, sorted by cut id."""
    cuts = CutSet.from_file(cut_path)
    refs = [
        (
            cut.id,
            list(cut.supervisions[0].custom["disfluent"]),
            cut.supervisions[0].custom["disfluent_tag"].split(","),
        )
        for cut in cuts
    ]
    refs.sort(key=lambda x: x[0])
    return refs


def main():
    """
    Score a single recogs file, or every recogs file in a directory.

    * If ``--recogs`` is a file whose stem starts with ``recogs-``,
      ``--cut`` must be a ``csj_cuts*`` manifest file.
    * If ``--recogs`` is a directory, ``--cut`` must be the directory that
      holds ``csj_cuts_<part>.jsonl.gz``. Each ``recogs-<part>-*.txt`` file
      is scored against its partition's manifest. Partitions without any
      recogs file are skipped, so their manifest is neither loaded nor
      required.

    Reports are written to ``--res-dir``, which is created if missing.

    Raises:
      TypeError: if ``--recogs`` is a file that is not a ``recogs-*`` file.
    """
    args = get_args()
    recogs_path: Path = args.recogs
    assert (
        recogs_path.is_file() or recogs_path.is_dir()
    ), f"recogs_file cannot be found at {recogs_path}."
    args.res_dir.mkdir(parents=True, exist_ok=True)

    if recogs_path.is_file() and recogs_path.stem.startswith("recogs-"):
        assert (
            "csj_cuts" in args.cut.name
        ), f"Expected {args.cut} to be a cuts manifest."
        for_each_recogs(recogs_path, _load_refs(args.cut), args.res_dir)
    elif recogs_path.is_dir():
        for partname in ["eval1", "eval2", "eval3", "excluded", "valid"]:
            # Sorted for a deterministic processing order (glob order is arbitrary).
            part_recogs = sorted(recogs_path.glob(f"recogs-{partname}-*.txt"))
            if not part_recogs:
                continue
            refs = _load_refs(args.cut / f"csj_cuts_{partname}.jsonl.gz")
            for recogs_file in part_recogs:
                for_each_recogs(recogs_file, refs, args.res_dir)
    else:
        raise TypeError(f"Unrecognised recogs file provided: {recogs_path}")
# Script entry point: parse the CLI arguments and score the recogs file(s).
if __name__ == "__main__":
    main()