import argparse
import ast
from pathlib import Path

import kaldialign
from lhotse import CutSet

ARGPARSE_DESCRIPTION = """
This helper code takes in a disfluent recogs file generated from icefall.utils.store_transcript,
compares it against a fluent transcript, and saves the results in a separate directory.
This is useful to compare disfluent models with fluent models on the same metric.
"""


def get_args():
    """Parse command-line arguments.

    Returns:
        argparse.Namespace with ``recogs``, ``cut`` and ``res_dir`` (all Path).
        ``recogs`` may be a single ``recogs-*`` file or a directory; in the
        directory case ``cut`` is treated as the directory holding the
        ``csj_cuts_<part>.jsonl.gz`` manifests (see ``main``).
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        description=ARGPARSE_DESCRIPTION,
    )
    parser.add_argument(
        "--recogs",
        type=Path,
        required=True,
        help="Path to the recogs-XXX file generated by icefall.utils.store_transcript.",
    )
    parser.add_argument(
        "--cut",
        type=Path,
        required=True,
        help="Path to the cut manifest to be compared to. Assumes that disfluent_tag exists in the custom dict.",
    )
    parser.add_argument(
        "--res-dir", type=Path, required=True, help="Path to save results"
    )
    return parser.parse_args()


def d2f(stats):
    """
    Compare the outputs of a disfluent model against a fluent reference.
    Indicates a disfluent model's performance only on the content words

    CER^d_f = (sub_f + ins + del_f) / Nf

    Args:
        stats: mapping with ``"base"`` (sub_f + ins + del_f) and ``"Nf"``
            (number of characters in the fluent reference).

    Returns:
        The error rate as a fraction (not a percentage).
    """
    return stats["base"] / stats["Nf"]


def _is_disfluent_tag(tag):
    """Return True if a CSJ per-character tag excludes the character from the fluent reference."""
    # The 'D' and 'F' here are CSJ annotation tags, not "disfluent"/"fluent".
    return "D" in tag or "F" in tag


def calc_cer(refs, hyps):
    """Count alignment errors split by whether the reference character is disfluent.

    Args:
        refs: sequence of ``(cut_id, ref_chars, ref_tags)``; ``ref_tags`` holds
            one CSJ tag string per reference character, in order.
        hyps: sequence of ``(cut_id, hyp_chars)`` in the same cut-id order as
            ``refs``.

    Returns:
        dict with ``subs``/``dels``/``cors`` (each ``{"D": n, "F": n}``),
        ``ins`` (int), ``dis_ref_len`` (total reference characters) and
        ``flu_ref_len`` (reference characters whose tag is not disfluent).
        Bucket "D" holds characters whose CSJ tag contains 'D' or 'F' (i.e.
        characters absent from the fluent reference); bucket "F" holds the rest.

    Raises:
        AssertionError: if paired ref/hyp cut ids differ.
        ValueError: if a reference has fewer tags than aligned characters.
    """
    subs = {"F": 0, "D": 0}
    ins = 0
    dels = {"F": 0, "D": 0}
    cors = {"F": 0, "D": 0}
    dis_ref_len = 0
    flu_ref_len = 0

    for (ref_id, ref_chars, ref_tags), (hyp_id, hyp_chars) in zip(refs, hyps):
        assert (
            ref_id == hyp_id
        ), f"Expected ref cut id {ref_id} to be the same as hyp cut id {hyp_id}."
        dis_ref_len += len(ref_chars)
        flu_ref_len += sum(1 for t in ref_tags if not _is_disfluent_tag(t))

        # Tags are consumed lazily, one per non-epsilon reference position;
        # an iterator avoids the O(n^2) list.pop(0) and never mutates refs
        # (refs are reused across several recogs files in directory mode).
        tag_iter = iter(ref_tags)
        for ref_word, hyp_word in kaldialign.align(ref_chars, hyp_chars, "*"):
            if ref_word == "*":
                # Insertions have no reference character and hence no tag.
                ins += 1
                continue
            tag = next(tag_iter, None)
            if tag is None:
                raise ValueError(
                    f"Cut {ref_id}: fewer disfluent tags ({len(ref_tags)}) "
                    f"than reference characters."
                )
            bucket = "D" if _is_disfluent_tag(tag) else "F"
            if hyp_word == "*":
                dels[bucket] += 1
            elif ref_word != hyp_word:
                subs[bucket] += 1
            else:
                cors[bucket] += 1

    return {
        "subs": subs,
        "ins": ins,
        "dels": dels,
        "cors": cors,
        "dis_ref_len": dis_ref_len,
        "flu_ref_len": flu_ref_len,
    }


def _read_hyps(recogs_file: Path):
    """Return ``(cut_id, hyp_tokens)`` pairs from a store_transcripts-style file.

    Only ``<cut_id>:\\thyp=<list repr>`` lines are used; ``ref=`` and any other
    fields are skipped.
    """
    hyps = []
    with recogs_file.open(encoding="utf-8") as fin:
        for line in fin:
            cutid, sep, field = line.partition(":\t")
            if not sep or not field.startswith("hyp="):
                continue
            # literal_eval instead of eval: the file is data, and only a list
            # literal is expected here -- arbitrary expressions are refused.
            hyps.append((cutid, ast.literal_eval(field[len("hyp="):].strip())))
    return hyps


def for_each_recogs(recogs_file: Path, refs, out_dir):
    """Score one recogs file against ``refs`` and write ``<stem>.dfcer`` to ``out_dir``.

    Args:
        recogs_file: recogs file produced by icefall's transcript storing.
        refs: ``(cut_id, ref_chars, ref_tags)`` tuples sorted by cut id.
        out_dir: directory receiving the ``.dfcer`` report.

    The report holds the CER^d_f as a percentage, an ``Nf,<count>`` line, a
    blank line, then a CSV of subs/dels/cors/ins split by disfluent tag.
    """
    hyps = _read_hyps(recogs_file)
    # refs are sorted by cut id; sort hyps the same way so zip pairs them up
    # (a no-op for files that are already sorted).
    hyps.sort(key=lambda x: x[0])
    assert len(refs) == len(
        hyps
    ), f"Expected refs len {len(refs)} and hyps len {len(hyps)} to be equal."

    stats = calc_cer(refs, hyps)

    # Columns: "yes" = disfluent-tag bucket ("D"), "no" = fluent bucket ("F").
    stat_lines = ["tag,yes,no"]
    for cer_type in ["subs", "dels", "cors", "ins"]:
        counts = stats[cer_type]
        if isinstance(counts, dict):
            stat_lines.append(f"{cer_type},{counts['D']},{counts['F']}")
        else:
            # Insertions are not split by tag; keep the trailing empty column.
            stat_lines.append(f"{cer_type},{counts},")
    stat_table = "\n".join(stat_lines)

    d2f_stats = {
        "Nf": stats["flu_ref_len"],
        "base": stats["subs"]["F"] + stats["ins"] + stats["dels"]["F"],
    }
    cer = d2f(d2f_stats)
    results = "\n".join([f"{cer:.2%}", f"Nf,{d2f_stats['Nf']}"])

    with (out_dir / (recogs_file.stem + ".dfcer")).open("w", encoding="utf-8") as fout:
        fout.write(results)
        fout.write("\n\n")
        fout.write(stat_table)


def _load_refs(cut_path: Path):
    """Load a cut manifest as ``(cut_id, disfluent_chars, disfluent_tags)`` sorted by cut id.

    Reads ``custom["disfluent"]`` and the comma-separated
    ``custom["disfluent_tag"]`` of each cut's first supervision.
    """
    cuts: CutSet = CutSet.from_file(cut_path)
    return sorted(
        (
            (
                cut.id,
                list(cut.supervisions[0].custom["disfluent"]),
                cut.supervisions[0].custom["disfluent_tag"].split(","),
            )
            for cut in cuts
        ),
        key=lambda x: x[0],
    )


def main():
    """Score a single recogs file, or every recogs file in a directory.

    Single file: ``--recogs`` must be named ``recogs-*`` and ``--cut`` is one
    ``csj_cuts*`` manifest. Directory: for each CSJ partition, every
    ``recogs-<part>-*.txt`` is scored against
    ``<--cut>/csj_cuts_<part>.jsonl.gz``.
    """
    args = get_args()
    recogs_file: Path = args.recogs
    assert (
        recogs_file.is_file() or recogs_file.is_dir()
    ), f"recogs_file cannot be found at {recogs_file}."
    args.res_dir.mkdir(parents=True, exist_ok=True)

    if recogs_file.is_file() and recogs_file.stem.startswith("recogs-"):
        assert (
            "csj_cuts" in args.cut.name
        ), f"Expected {args.cut} to be a cuts manifest."
        for_each_recogs(recogs_file, _load_refs(args.cut), args.res_dir)
    elif recogs_file.is_dir():
        for partname in ["eval1", "eval2", "eval3", "excluded", "valid"]:
            part_recogs = sorted(recogs_file.glob(f"recogs-{partname}-*.txt"))
            if not part_recogs:
                # Nothing decoded for this partition; don't load its manifest.
                continue
            refs = _load_refs(args.cut / f"csj_cuts_{partname}.jsonl.gz")
            for part_file in part_recogs:
                for_each_recogs(part_file, refs, args.res_dir)
    else:
        raise TypeError(f"Unrecognised recogs file provided: {recogs_file}")


if __name__ == "__main__":
    main()