icefall/egs/seame/ASR/local/wer_lang.py

#!/usr/bin/env python3
# Johns Hopkins University (authors: Amir Hussein)
"""
Compute WER per language
"""

import argparse
from collections import defaultdict

from kaldialign import align
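
# Usage sketch (the path below is illustrative; point --rec at whatever
# recognition results file your decoding run produced):
#
#   ./local/wer_lang.py --rec exp/greedy_search/recogs-test.txt
#
# Each scored line in that file is assumed to look like
#   <cut_id>:  ref=['word1', 'word2', ...]
#   <cut_id>:  hyp=['word1', 'word2', ...]
# i.e. the recogs-*.txt format written by icefall's decoding scripts.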


def get_parser():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    parser.add_argument(
        "--rec",
        type=str,
        default="",
        help="Recognition results file containing ref=[...] and hyp=[...] lines",
    )
    return parser


# Supported language tags: English ("en") and Chinese ("zh").
lids = "en,zh"
lids_dict = {lid: i + 1 for i, lid in enumerate(lids.split(","))}  # {"en": 1, "zh": 2}
id2lang = {i + 1: lid for i, lid in enumerate(lids.split(","))}  # {1: "en", 2: "zh"}
# Utterance IDs whose lid field turned out to be empty.
bad_id = []


def extract_info(line, info):
    """Extract the utterance ID, the word list tagged by `info`, and the
    language-ID list (if present) from one line of the results file."""
    # Split the line at the first colon to separate the utterance ID.
    id_part, rest = line.split(":", 1)
    # Extract the word list by locating the opening tag and the closing "]".
    ref_start = rest.find(info)
    ref_end = rest.find("]", ref_start)
    ref = rest[ref_start + len(info) : ref_end].replace("'", "").split(", ")
    # Extract the language IDs, if any.
    if "lid=" in rest:
        lid_start = rest.find("lid=[")
        lid_end = rest.find("]", lid_start)
        lid = rest[lid_start + len("lid=[") : lid_end].split(", ")
    else:
        lid = [""]
    if lid[0] == "":
        bad_id.append(id_part)
    if " ".join(lid):
        lid = [int(i) for i in lid]  # Convert each element to an integer
    return id_part.strip(), ref, lid


def is_English(c):
    """Return True if the character is an ASCII English letter."""
    return ord(c.lower()) >= ord("a") and ord(c.lower()) <= ord("z")


def get_en(text):
    """Keep only the words that start with an English letter."""
    res = []
    for w in text:
        if w and is_English(w[0]):
            res.append(w)
    return res


def get_zh(text):
    """Keep only the words that do not start with an English letter."""
    res = []
    for w in text:
        if w and not is_English(w[0]):
            res.append(w)
    return res
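
# For example, a code-switched utterance is split as follows (the words are
# made up for illustration):
#   get_en(["i", "喜欢", "coffee"]) -> ["i", "coffee"]
#   get_zh(["i", "喜欢", "coffee"]) -> ["喜欢"]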


def extract_info_lid(line, tag):
    """Extract the utterance ID and the word list tagged by `tag` (e.g. "ref=[")."""
    # Split the line at the first colon to separate the utterance ID.
    id_part, rest = line.split(":", 1)
    # Extract the word list by locating the opening tag and the closing "]".
    ref_start = rest.find(tag)
    ref_end = rest.find("]", ref_start)
    ref = rest[ref_start + len(tag) : ref_end].replace("'", "").split(", ")
    return id_part.strip(), ref


def align_lid2(labels_a, labels_b, a, b):
    """Align the word sequences `a` and `b` and return the language labels of
    the word pairs that end up aligned to each other (insertions and deletions
    are skipped)."""
    EPS = "*"
    ali = align(a, b, EPS, sclite_mode=True)
    # Map (word, position) -> language label for both sequences.
    a2idx = {(i, idx): j for idx, (i, j) in enumerate(zip(a, labels_a))}
    b2idx = {(i, idx): j for idx, (i, j) in enumerate(zip(b, labels_b))}
    # Walk through the alignment and compare the labels of aligned elements.
    idx_a = 0
    idx_b = 0
    ali_idx = 0
    aligned_a = []
    aligned_b = []
    while idx_a < len(a) and idx_b < len(b) and ali_idx < len(ali):
        elem_a, elem_b = ali[ali_idx]
        if elem_a == EPS:
            idx_b += 1
        elif elem_b == EPS:
            idx_a += 1
        else:
            label_a = a2idx[(elem_a, idx_a)]
            label_b = b2idx[(elem_b, idx_b)]
            aligned_a.append(label_a)
            aligned_b.append(label_b)
            idx_b += 1
            idx_a += 1
        ali_idx += 1
    return aligned_a, aligned_b


def align_lid(labels_a, labels_b):
    """Align two label sequences and return them padded with "*" so that they
    have equal length and aligned positions correspond to each other."""
    res_a, res_b = [], []
    EPS = "*"
    ali = align(labels_a, labels_b, EPS, sclite_mode=True)
    for val_a, val_b in ali:
        res_a.append(val_a)
        res_b.append(val_b)
    return res_a, res_b
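
# Illustrative call (hypothetical label sequences): align_lid([1, 1, 2], [1, 2])
# returns two equal-length lists in which the shorter input is padded with "*"
# at the positions chosen by the alignment, e.g. ([1, 1, 2], [1, "*", 2]).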


def read_file(infile, tag):
    """Return a list of (utterance_id, word_list) tuples for every line of
    `infile` that contains `tag` (e.g. "ref=[" or "hyp=[")."""
    res = []
    with open(infile, "r") as file:
        for line in file:
            _, rest = line.split(":", 1)
            if tag in rest:
                _id, text = extract_info_lid(line, tag)
                res.append((_id, text))
    return res
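
# For example (assumed input line format), given the single line
#   "cut-001:  ref=['hello', '世界']"
# read_file(path, "ref=[") returns [("cut-001", ["hello", "世界"])].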


def wer(results, sclite_mode=False):
    """Compute and print the overall WER for a list of (cut_id, ref, hyp)
    tuples, where ref and hyp are lists of words."""
    subs = defaultdict(int)
    ins = defaultdict(int)
    dels = defaultdict(int)
    # `words` stores counts per word, as follows:
    #   corr, ref_sub, hyp_sub, ins, dels
    words = defaultdict(lambda: [0, 0, 0, 0, 0])
    num_corr = 0
    ERR = "*"
    for cut_id, ref, hyp in results:
        ali = align(ref, hyp, ERR, sclite_mode=sclite_mode)
        for ref_word, hyp_word in ali:
            if ref_word == ERR:
                ins[hyp_word] += 1
                words[hyp_word][3] += 1
            elif hyp_word == ERR:
                dels[ref_word] += 1
                words[ref_word][4] += 1
            elif hyp_word != ref_word:
                subs[(ref_word, hyp_word)] += 1
                words[ref_word][1] += 1
                words[hyp_word][2] += 1
            else:
                words[ref_word][0] += 1
                num_corr += 1
    ref_len = sum([len(r) for _, r, _ in results])
    sub_errs = sum(subs.values())
    ins_errs = sum(ins.values())
    del_errs = sum(dels.values())
    tot_errs = sub_errs + ins_errs + del_errs
    tot_err_rate = "%.2f" % (100.0 * tot_errs / ref_len)
    print(f"%WER = {tot_err_rate}")
    return tot_err_rate
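
# Worked example (made-up data): for a single utterance with
#   ref = ["a", "b", "c"] and hyp = ["a", "x", "c", "d"]
# there is 1 substitution (b -> x) and 1 insertion (d) against 3 reference
# words, so wer([("cut-001", ["a", "b", "c"], ["a", "x", "c", "d"])])
# prints "%WER = 66.67".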


if __name__ == "__main__":
    parser = get_parser()
    args = parser.parse_args()

    ref_data = sorted(read_file(args.rec, "ref=["))
    hyp_data = sorted(read_file(args.rec, "hyp=["))

    results = defaultdict(list)
    for ref, hyp in zip(ref_data, hyp_data):
        assert ref[0] == hyp[0], f"ref_id: {ref[0]} != hyp_id: {hyp[0]}"
        _, ref_text = ref
        _, hyp_text = hyp
        # Split each utterance into its English and Chinese words and score
        # the two subsets separately.
        results["en"].append((ref[0], get_en(ref_text), get_en(hyp_text)))
        results["zh"].append((ref[0], get_zh(ref_text), get_zh(hyp_text)))

    for key, val in results.items():
        print(key)
        wer(val)
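
# The script prints, for each language, the language tag followed by its WER,
# e.g. (illustrative numbers):
#   en
#   %WER = 25.31
#   zh
#   %WER = 18.42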