#!/usr/bin/env python3
# Johns Hopkins University (authors: Amir Hussein)

"""
Compute WER separately for each language (English and Chinese) in
code-switched recognition results.
"""

import argparse
from collections import defaultdict

from kaldialign import align


def get_parser():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--rec",
        type=str,
        default="",
        help="Recognition results file containing ref=[...] and hyp=[...] lines",
    )
    return parser


lids = "en,zh"
lids_dict = {lid: id + 1 for id, lid in enumerate(lids.split(","))}
id2lang = {id + 1: lid for id, lid in enumerate(lids.split(","))}
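# i.e. lids_dict == {"en": 1, "zh": 2} and id2lang == {1: "en", 2: "zh"}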
bad_id = []


def extract_info(line, info):
    # Split the line at the first colon to separate the cut ID.
    id_part, rest = line.split(":", 1)

    # Extract the token list that follows `info` (e.g. "ref=[") up to
    # the closing bracket.
    ref_start = rest.find(info)
    ref_end = rest.find("]", ref_start)
    ref = rest[ref_start + len(info) : ref_end].replace("'", "").split(", ")

    # Extract the language-ID list, if present.
    if "lid=" in rest:
        lid_start = rest.find("lid=[")
        lid_end = rest.find("]", lid_start)
        lid = rest[lid_start + len("lid=[") : lid_end].split(", ")
    else:
        lid = [""]

    if lid[0] == "":
        bad_id.append(id_part)
    if " ".join(lid):
        lid = [int(i) for i in lid]  # Convert each element to an integer
    return id_part.strip(), ref, lid
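# Example (illustrative line; not from any real dataset):
#   extract_info("cut-1: ref=['hello', '你好'], lid=[1, 2]", "ref=[")
#   -> ("cut-1", ["hello", "你好"], [1, 2])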


def is_English(c):
    """Check whether a character is an English (ASCII) letter."""
    return ord(c.lower()) >= ord("a") and ord(c.lower()) <= ord("z")


def get_en(text):
    """Keep only the tokens that start with an English letter."""
    res = []
    for w in text:
        if w and is_English(w[0]):
            res.append(w)
    return res


def get_zh(text):
    """Keep only the tokens that do not start with an English letter."""
    res = []
    for w in text:
        if w and not is_English(w[0]):
            res.append(w)
    return res
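# Example (illustrative tokens):
#   get_en(["hello", "你好", "world"])  ->  ["hello", "world"]
#   get_zh(["hello", "你好", "world"])  ->  ["你好"]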


def extract_info_lid(line, tag):
    # Split the line at the first colon to separate the cut ID.
    id_part, rest = line.split(":", 1)

    # Extract the token list that follows `tag` (e.g. "ref=[" or
    # "hyp=[") up to the closing bracket.
    ref_start = rest.find(tag)
    ref_end = rest.find("]", ref_start)
    ref = rest[ref_start + len(tag) : ref_end].replace("'", "").split(", ")

    return id_part.strip(), ref
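# Example (illustrative line):
#   extract_info_lid("cut-1: hyp=['hello', '你好']", "hyp=[")
#   -> ("cut-1", ["hello", "你好"])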


def align_lid2(labels_a, labels_b, a, b):
    # Align the two token sequences, then map each aligned token pair
    # back to its label (language ID).
    EPS = "*"
    ali = align(a, b, EPS, sclite_mode=True)

    a2idx = {(i, idx): j for idx, (i, j) in enumerate(zip(a, labels_a))}
    b2idx = {(i, idx): j for idx, (i, j) in enumerate(zip(b, labels_b))}
    # Compare the labels of aligned elements, skipping insertions and
    # deletions (positions aligned to EPS).
    idx_a = 0
    idx_b = 0
    ali_idx = 0
    aligned_a = []
    aligned_b = []
    while idx_a < len(a) and idx_b < len(b) and ali_idx < len(ali):
        elem_a, elem_b = ali[ali_idx]
        if elem_a == EPS:
            idx_b += 1
        elif elem_b == EPS:
            idx_a += 1
        else:
            label_a = a2idx[(elem_a, idx_a)]
            label_b = b2idx[(elem_b, idx_b)]
            aligned_a.append(label_a)
            aligned_b.append(label_b)
            idx_b += 1
            idx_a += 1

        ali_idx += 1
    return aligned_a, aligned_b
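# Example (illustrative): a substitution keeps both positions, so both
# label sequences are returned in full:
#   align_lid2([1, 2], [1, 1], ["hello", "你好"], ["hello", "hi"])
#   -> ([1, 2], [1, 1])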


def align_lid(labels_a, labels_b):
    # Align the two label sequences directly; unmatched positions are
    # padded with EPS ("*").
    res_a, res_b = [], []
    EPS = "*"
    ali = align(labels_a, labels_b, EPS, sclite_mode=True)

    # Collect the aligned label pairs.
    for val_a, val_b in ali:
        res_a.append(val_a)
        res_b.append(val_b)
    return res_a, res_b
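# Example (illustrative): a substitution keeps both labels aligned:
#   align_lid(["en", "zh"], ["en", "en"])  ->  (["en", "zh"], ["en", "en"])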


def read_file(infile, tag):
    """Return a list of (cut_id, tokens) tuples for every line in
    `infile` that contains `tag` (e.g. "ref=[" or "hyp=[")."""
    res = []
    with open(infile, "r") as file:
        for line in file:
            _, rest = line.split(":", 1)
            if tag in rest:
                _id, text = extract_info_lid(line, tag)
                res.append((_id, text))
    return res
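# Example (illustrative file contents): if recogs.txt contains the line
# "cut-1: ref=['hello', '你好']", then
#   read_file("recogs.txt", "ref=[")  ->  [("cut-1", ["hello", "你好"])]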


def wer(results, sclite_mode=False):
    subs = defaultdict(int)
    ins = defaultdict(int)
    dels = defaultdict(int)
    # `words` stores counts per word, as follows:
    #   corr, ref_sub, hyp_sub, ins, dels
    words = defaultdict(lambda: [0, 0, 0, 0, 0])
    num_corr = 0
    ERR = "*"
    for cut_id, ref, hyp in results:
        ali = align(ref, hyp, ERR, sclite_mode=sclite_mode)
        for ref_word, hyp_word in ali:
            if ref_word == ERR:
                ins[hyp_word] += 1
                words[hyp_word][3] += 1
            elif hyp_word == ERR:
                dels[ref_word] += 1
                words[ref_word][4] += 1
            elif hyp_word != ref_word:
                subs[(ref_word, hyp_word)] += 1
                words[ref_word][1] += 1
                words[hyp_word][2] += 1
            else:
                words[ref_word][0] += 1
                num_corr += 1
    ref_len = sum([len(r) for _, r, _ in results])
    sub_errs = sum(subs.values())
    ins_errs = sum(ins.values())
    del_errs = sum(dels.values())
    tot_errs = sub_errs + ins_errs + del_errs
    tot_err_rate = "%.2f" % (100.0 * tot_errs / ref_len)
    print(f"%WER = {tot_err_rate}")
    return tot_err_rate
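# Example (illustrative): one substitution out of three reference words:
#   wer([("cut-1", ["a", "b", "c"], ["a", "x", "c"])])
#   prints "%WER = 33.33" and returns "33.33"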


if __name__ == "__main__":
    parser = get_parser()
    args = parser.parse_args()
    ref_data = read_file(args.rec, "ref=[")
    ref_data = sorted(ref_data)
    hyp_data = read_file(args.rec, "hyp=[")
    hyp_data = sorted(hyp_data)
    results = defaultdict(list)

    for ref, hyp in zip(ref_data, hyp_data):
        assert ref[0] == hyp[0], f"ref_id: {ref[0]} != hyp_id: {hyp[0]}"
        _, ref_text = ref
        _, hyp_text = hyp
        # Split each utterance into its English and Chinese tokens.
        ref_text_en = get_en(ref_text)
        ref_text_zh = get_zh(ref_text)
        hyp_text_en = get_en(hyp_text)
        hyp_text_zh = get_zh(hyp_text)

        results["en"].append((ref[0], ref_text_en, hyp_text_en))
        results["zh"].append((ref[0], ref_text_zh, hyp_text_zh))

    # Report WER separately for each language.
    for key, val in results.items():
        print(key)
        wer(val)
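
# Expected output shape (the numbers are illustrative):
#   en
#   %WER = 12.34
#   zh
#   %WER = 23.45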