mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-06 23:54:17 +00:00
58 lines
1.5 KiB
Python
58 lines
1.5 KiB
Python
#!/usr/bin/python
|
|
# Copyright 2023 Johns Hopkins University (Amir Hussein)
|
|
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
|
|
|
|
"""
|
|
This script computes CER for the decodings generated by icefall recipe
|
|
"""
|
|
|
|
import argparse
import ast
import os

import jiwer
|
|
|
|
def get_args():
    """Build the command-line parser for this script.

    Note that, despite the name, this returns the parser itself rather than
    the parsed arguments; the caller is expected to invoke ``parse_args()``
    on the result.

    Returns:
      An ``argparse.ArgumentParser`` that accepts a single ``--dec-file``
      option.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--dec-file",
        type=str,
        # Without this the script would pass None to open() and die with an
        # unhelpful TypeError; let argparse report the missing option.
        required=True,
        help="file with decoded text",
    )
    return parser
|
|
|
|
def _parse_token_list(text):
    """Convert the printed form of a Python list of words into a list of str.

    The text is first parsed as a Python literal so that every quoting style
    produced by ``repr`` (e.g. ``"don't"``) is decoded correctly. If the text
    is not a valid list/tuple literal, fall back to stripping the surrounding
    brackets and single quotes by hand.
    """
    text = text.strip()
    try:
        tokens = ast.literal_eval(text)
    except (ValueError, SyntaxError):
        tokens = None
    if isinstance(tokens, (list, tuple)):
        return [str(token) for token in tokens]
    return [word.strip("'") for word in text.strip("[]").split(", ")]


def cer_(file):
    """Print the character error rate (CER) of a decoding-results file.

    Each non-blank line of the file is expected to look like
    ``<utterance-id>\\t<kind>=<list of words>`` where ``<kind>`` is ``ref``
    or ``hyp``. Lines of any other kind are ignored. References and
    hypotheses are paired in the order in which they appear.

    The function prints the base name of ``file`` followed by the CER, i.e.
    the per-utterance ``jiwer.cer`` values weighted by reference length.
    If no reference characters are found, a message is printed instead of
    the CER.

    Args:
      file: Path to the UTF-8 encoded decoding-results file.

    Raises:
      ValueError: If a non-blank line has no tab separator or no ``=``
        after the tab.
    """
    hyp = []
    ref = []
    with open(file, "r", encoding="utf-8") as dec:
        for line_no, line in enumerate(dec, start=1):
            if not line.strip():
                # Blank lines carry no transcript; skipping them cannot
                # shift the ref/hyp pairing.
                continue
            # Split only on the first tab / first "=" so that transcripts
            # containing those characters are kept intact.
            _utt_id, tab, payload = line.partition("\t")
            if not tab:
                raise ValueError(
                    f"{file}:{line_no}: expected '<id>\\t<ref|hyp>=<words>'"
                )
            target, equals, txt = payload.partition("=")
            if not equals:
                raise ValueError(
                    f"{file}:{line_no}: expected '<ref|hyp>=<words>' after tab"
                )
            if target == "ref":
                ref.append(" ".join(_parse_token_list(txt)))
            elif target == "hyp":
                hyp.append(" ".join(_parse_token_list(txt)))

    cer_results = 0
    ref_lens = 0
    # NOTE: zip() stops at the shorter list, so an unmatched trailing ref or
    # hyp is left out of the score.
    for h, r in zip(hyp, ref):
        # jiwer.cer is a per-utterance rate; scale it by the reference
        # length so utterances are weighted by their number of characters.
        cer_results += jiwer.cer(r, h) * len(r)
        ref_lens += len(r)

    print(os.path.basename(file))
    if ref_lens == 0:
        # Avoid ZeroDivisionError on files with no usable ref/hyp pairs.
        print("CER undefined: no reference characters found")
        return
    print(cer_results / ref_lens)
|
|
|
|
|
|
|
|
|
|
def main():
    """Parse the command line and report the CER of the given decoding file."""
    cli_args = get_args().parse_args()
    cer_(cli_args.dec_file)
|
|
|
|
# Run the CER computation only when executed as a script, not on import.
if __name__ == "__main__":
    main()