mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
from local
This commit is contained in:
parent
cece0e955d
commit
73a1d39c86
Binary file not shown.
236
egs/librispeech/ASR/conformer_ctc2/asr_metrics.py
Normal file
236
egs/librispeech/ASR/conformer_ctc2/asr_metrics.py
Normal file
@ -0,0 +1,236 @@
|
|||||||
|
from __future__ import unicode_literals
|
||||||
|
import logging
|
||||||
|
from typing import Any, Dict, List, Tuple, Union
|
||||||
|
import sys
|
||||||
|
import pandas as pd
|
||||||
|
import jiwer
|
||||||
|
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
|
||||||
|
def levenshtein(u, v):
|
||||||
|
prev = None
|
||||||
|
curr = [0] + list(range(1, len(v) + 1))
|
||||||
|
# Operations: (SUB, DEL, INS)
|
||||||
|
prev_ops = None
|
||||||
|
curr_ops = [(0, 0, i) for i in range(len(v) + 1)]
|
||||||
|
for x in range(1, len(u) + 1):
|
||||||
|
prev, curr = curr, [x] + ([None] * len(v))
|
||||||
|
prev_ops, curr_ops = curr_ops, [(0, x, 0)] + ([None] * len(v))
|
||||||
|
for y in range(1, len(v) + 1):
|
||||||
|
delcost = prev[y] + 1
|
||||||
|
addcost = curr[y - 1] + 1
|
||||||
|
subcost = prev[y - 1] + int(u[x - 1] != v[y - 1])
|
||||||
|
curr[y] = min(subcost, delcost, addcost)
|
||||||
|
if curr[y] == subcost:
|
||||||
|
(n_s, n_d, n_i) = prev_ops[y - 1]
|
||||||
|
curr_ops[y] = (n_s + int(u[x - 1] != v[y - 1]), n_d, n_i)
|
||||||
|
elif curr[y] == delcost:
|
||||||
|
(n_s, n_d, n_i) = prev_ops[y]
|
||||||
|
curr_ops[y] = (n_s, n_d + 1, n_i)
|
||||||
|
else:
|
||||||
|
(n_s, n_d, n_i) = curr_ops[y - 1]
|
||||||
|
curr_ops[y] = (n_s, n_d, n_i + 1)
|
||||||
|
return curr[len(v)], curr_ops[len(v)]
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def get_unicode_code(text):
|
||||||
|
result = ''.join( char if ord(char) < 128 else '\\u'+format(ord(char), 'x') for char in text )
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def _measure_cer(
|
||||||
|
reference : str, transcription : str
|
||||||
|
) -> Tuple[int, int, int, int]:
|
||||||
|
"""
|
||||||
|
소스 단어를 대상 단아로 변환하는 데 필요한 편집 작업(삭제, 삽입, 바꾸기)의 수를 확인합니다.
|
||||||
|
hints 횟수는 소스 딘아의 전체 길이에서 삭제 및 대체 횟수를 빼서 제공할 수 있습니다.
|
||||||
|
|
||||||
|
:param transcription: 대상 단어로 변환할 소스 문자열
|
||||||
|
:param reference: 소스 단어
|
||||||
|
:return: a tuple of #hits, #substitutions, #deletions, #insertions
|
||||||
|
"""
|
||||||
|
|
||||||
|
ref, hyp = [], []
|
||||||
|
|
||||||
|
ref.append(reference)
|
||||||
|
hyp.append(transcription)
|
||||||
|
|
||||||
|
#print("? : ", ref)
|
||||||
|
|
||||||
|
cer_s, cer_i, cer_d, cer_n = 0, 0, 0, 0
|
||||||
|
sen_err = 0
|
||||||
|
|
||||||
|
for n in range(len(ref)):
|
||||||
|
# update CER statistics
|
||||||
|
_, (s, i, d) = levenshtein(hyp[n], ref[n])
|
||||||
|
cer_s += s
|
||||||
|
cer_i += i
|
||||||
|
cer_d += d
|
||||||
|
cer_n += len(ref[n])
|
||||||
|
|
||||||
|
# update SER statistics
|
||||||
|
if s + i + d > 0:
|
||||||
|
sen_err += 1
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
'''
|
||||||
|
print("reference : ",reference)
|
||||||
|
print("cer S : ", cer_s)
|
||||||
|
print("cer I : ", cer_i)
|
||||||
|
print("cer D : ", cer_d)
|
||||||
|
print("cer_n : ", cer_n)
|
||||||
|
|
||||||
|
|
||||||
|
if cer_n > 0:
|
||||||
|
print('CER: %g%%, SER: %g%%' % (
|
||||||
|
(100.0 * (cer_s + cer_i + cer_d)) / cer_n,
|
||||||
|
(100.0 * sen_err) / len(ref)))
|
||||||
|
'''
|
||||||
|
substitutions = cer_s
|
||||||
|
deletions = cer_d
|
||||||
|
insertions = cer_i
|
||||||
|
hits = len(reference) - (substitutions + deletions) #correct characters
|
||||||
|
|
||||||
|
return hits, substitutions, deletions, insertions
|
||||||
|
|
||||||
|
def _measure_wer(
|
||||||
|
reference : str, transcription : str
|
||||||
|
) -> Tuple[int, int, int, int]:
|
||||||
|
"""
|
||||||
|
소스 문자열을 대상 문자열로 변환하는 데 필요한 편집 작업(삭제, 삽입, 바꾸기)의 수를 확인합니다.
|
||||||
|
hints 횟수는 소스 문자열의 전체 길이에서 삭제 및 대체 횟수를 빼서 제공할 수 있습니다.
|
||||||
|
|
||||||
|
:param transcription: 대상 단어
|
||||||
|
:param reference: 소스 단어
|
||||||
|
:return: a tuple of #hits, #substitutions, #deletions, #insertions
|
||||||
|
"""
|
||||||
|
|
||||||
|
ref, hyp = [], []
|
||||||
|
|
||||||
|
ref.append(reference)
|
||||||
|
hyp.append(transcription)
|
||||||
|
|
||||||
|
#print("? : ", ref)
|
||||||
|
|
||||||
|
wer_s, wer_i, wer_d, wer_n = 0, 0, 0, 0
|
||||||
|
sen_err = 0
|
||||||
|
|
||||||
|
for n in range(len(ref)):
|
||||||
|
# update WER statistics
|
||||||
|
_, (s, i, d) = levenshtein(hyp[n].split(), ref[n].split())
|
||||||
|
wer_s += s
|
||||||
|
wer_i += i
|
||||||
|
wer_d += d
|
||||||
|
wer_n += len(ref[n].split())
|
||||||
|
# update SER statistics
|
||||||
|
if s + i + d > 0:
|
||||||
|
sen_err += 1
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
#print("reference : ",reference)
|
||||||
|
#print("reference cnt : ", reference.split())
|
||||||
|
#print("wer S : ", wer_s)
|
||||||
|
#print("wer I : ", wer_i)
|
||||||
|
#print("wer D : ", wer_d)
|
||||||
|
#print("wer_n : ", wer_n)
|
||||||
|
|
||||||
|
|
||||||
|
if wer_n > 0:
|
||||||
|
print('WER: %g%%, SER: %g%%' % (
|
||||||
|
(100.0 * (wer_s + wer_i + wer_d)) / wer_n,
|
||||||
|
(100.0 * sen_err) / len(ref)))
|
||||||
|
|
||||||
|
substitutions = wer_s
|
||||||
|
deletions = wer_d
|
||||||
|
insertions = wer_i
|
||||||
|
hits = len(reference.split()) - (substitutions + deletions) #correct words between refs and trans
|
||||||
|
|
||||||
|
return hits, substitutions, deletions, insertions
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def _measure_er(
|
||||||
|
reference : str, transcription : str
|
||||||
|
) -> Tuple[int, int]:
|
||||||
|
"""
|
||||||
|
TBD
|
||||||
|
:param transcription: 대상 문자열로 변환할 소스 문자열
|
||||||
|
:param reference:
|
||||||
|
:return: a tuple of #
|
||||||
|
"""
|
||||||
|
TBD1 =""
|
||||||
|
TBD2 =""
|
||||||
|
return TBD1, TBD2
|
||||||
|
|
||||||
|
|
||||||
|
def get_cer(reference, transcription, rm_punctuation = True
|
||||||
|
) -> Tuple[int, int, int, int]:
|
||||||
|
|
||||||
|
# 문자 오류율(CER)은 자동 음성 인식 시스템의 성능에 대한 일반적인 메트릭입니다.
|
||||||
|
# CER은 WER(단어 오류율)과 유사하지만 단어 대신 문자에 대해 작동합니다.
|
||||||
|
# 이 코드에서는 문제는 사람들이 띄어쓰기를 지키지 않고 작성한 텍스트를 컴퓨터가 정확하게 인식하는 것이 매우 어렵기 때문에 인식에러에서 생략합니다.
|
||||||
|
# CER의 출력은 특히 삽입 수가 많은 경우 항상 0과 1 사이의 숫자가 아닙니다. 이 값은 종종 잘못 예측된 문자의 백분율과 연관됩니다. 값이 낮을수록 좋습니다.
|
||||||
|
# CER이 0인 ASR 시스템의 성능은 완벽한 점수입니다.
|
||||||
|
|
||||||
|
# CER = (S + D + I) / N = (S + D + I) / (S + D + C)
|
||||||
|
# S is the number of the substitutions,
|
||||||
|
# D is the number of the deletions,
|
||||||
|
# I is the number of the insertions,
|
||||||
|
# C is the number of the correct characters,
|
||||||
|
# N is the number of the characters in the reference (N=S+D+C).
|
||||||
|
|
||||||
|
refs = jiwer.RemoveWhiteSpace(replace_by_space=False)(reference)
|
||||||
|
trans = jiwer.RemoveWhiteSpace(replace_by_space=False)(transcription)
|
||||||
|
|
||||||
|
if rm_punctuation == True:
|
||||||
|
refs = jiwer.RemovePunctuation()(refs)
|
||||||
|
trans = jiwer.RemovePunctuation()(trans)
|
||||||
|
else:
|
||||||
|
refs = reference
|
||||||
|
trans = transcription
|
||||||
|
|
||||||
|
#print("refs : ", refs)
|
||||||
|
|
||||||
|
[hits ,cer_s, cer_d, cer_i] = _measure_cer(refs, trans)
|
||||||
|
|
||||||
|
substitutions = cer_s
|
||||||
|
deletions = cer_d
|
||||||
|
insertions = cer_i
|
||||||
|
#print("tmp hits : ", hits)
|
||||||
|
incorrect = substitutions + deletions + insertions
|
||||||
|
total = substitutions + deletions + hits + insertions
|
||||||
|
|
||||||
|
cer = incorrect / total
|
||||||
|
return cer, substitutions, deletions, insertions
|
||||||
|
|
||||||
|
|
||||||
|
def get_wer(reference, transcription, rm_punctuation = True
|
||||||
|
)-> Tuple[int, int, int, int]:
|
||||||
|
|
||||||
|
# WER = (S + D + I) / N = (S + D + I) / (S + D + C)
|
||||||
|
# S is the number of the substitutions,
|
||||||
|
# D is the number of the deletions,
|
||||||
|
# I is the number of the insertions,
|
||||||
|
# C is the number of the correct words,
|
||||||
|
# N is the number of the words in the reference (N=S+D+C).
|
||||||
|
if rm_punctuation == True:
|
||||||
|
refs = jiwer.RemovePunctuation()(reference)
|
||||||
|
trans = jiwer.RemovePunctuation()(transcription)
|
||||||
|
else:
|
||||||
|
refs = reference
|
||||||
|
trans = transcription
|
||||||
|
[hits, wer_s, wer_d, wer_i] = _measure_wer(refs, trans)
|
||||||
|
|
||||||
|
substitutions = wer_s
|
||||||
|
deletions = wer_d
|
||||||
|
insertions = wer_i
|
||||||
|
#print("tmp hits : ", hits)
|
||||||
|
incorrect = substitutions + deletions + insertions
|
||||||
|
total = substitutions + deletions + hits + insertions
|
||||||
|
|
||||||
|
wer = incorrect / total
|
||||||
|
return wer, substitutions, deletions, insertions
|
||||||
Loading…
x
Reference in New Issue
Block a user