from local

This commit is contained in:
dohe0342 2023-02-02 19:16:31 +09:00
parent cece0e955d
commit 73a1d39c86
2 changed files with 236 additions and 0 deletions

View File

@ -0,0 +1,236 @@
from __future__ import unicode_literals
import logging
from typing import Any, Dict, List, Tuple, Union
import sys
import pandas as pd
import jiwer
# -*- coding: utf-8 -*-
def levenshtein(u, v):
prev = None
curr = [0] + list(range(1, len(v) + 1))
# Operations: (SUB, DEL, INS)
prev_ops = None
curr_ops = [(0, 0, i) for i in range(len(v) + 1)]
for x in range(1, len(u) + 1):
prev, curr = curr, [x] + ([None] * len(v))
prev_ops, curr_ops = curr_ops, [(0, x, 0)] + ([None] * len(v))
for y in range(1, len(v) + 1):
delcost = prev[y] + 1
addcost = curr[y - 1] + 1
subcost = prev[y - 1] + int(u[x - 1] != v[y - 1])
curr[y] = min(subcost, delcost, addcost)
if curr[y] == subcost:
(n_s, n_d, n_i) = prev_ops[y - 1]
curr_ops[y] = (n_s + int(u[x - 1] != v[y - 1]), n_d, n_i)
elif curr[y] == delcost:
(n_s, n_d, n_i) = prev_ops[y]
curr_ops[y] = (n_s, n_d + 1, n_i)
else:
(n_s, n_d, n_i) = curr_ops[y - 1]
curr_ops[y] = (n_s, n_d, n_i + 1)
return curr[len(v)], curr_ops[len(v)]
def get_unicode_code(text):
result = ''.join( char if ord(char) < 128 else '\\u'+format(ord(char), 'x') for char in text )
return result
def _measure_cer(
reference : str, transcription : str
) -> Tuple[int, int, int, int]:
"""
소스 단어를 대상 단아로 변환하는 필요한 편집 작업(삭제, 삽입, 바꾸기) 수를 확인합니다.
hints 횟수는 소스 딘아의 전체 길이에서 삭제 대체 횟수를 빼서 제공할 있습니다.
:param transcription: 대상 단어로 변환할 소스 문자열
:param reference: 소스 단어
:return: a tuple of #hits, #substitutions, #deletions, #insertions
"""
ref, hyp = [], []
ref.append(reference)
hyp.append(transcription)
#print("? : ", ref)
cer_s, cer_i, cer_d, cer_n = 0, 0, 0, 0
sen_err = 0
for n in range(len(ref)):
# update CER statistics
_, (s, i, d) = levenshtein(hyp[n], ref[n])
cer_s += s
cer_i += i
cer_d += d
cer_n += len(ref[n])
# update SER statistics
if s + i + d > 0:
sen_err += 1
'''
print("reference : ",reference)
print("cer S : ", cer_s)
print("cer I : ", cer_i)
print("cer D : ", cer_d)
print("cer_n : ", cer_n)
if cer_n > 0:
print('CER: %g%%, SER: %g%%' % (
(100.0 * (cer_s + cer_i + cer_d)) / cer_n,
(100.0 * sen_err) / len(ref)))
'''
substitutions = cer_s
deletions = cer_d
insertions = cer_i
hits = len(reference) - (substitutions + deletions) #correct characters
return hits, substitutions, deletions, insertions
def _measure_wer(
reference : str, transcription : str
) -> Tuple[int, int, int, int]:
"""
소스 문자열을 대상 문자열로 변환하는 필요한 편집 작업(삭제, 삽입, 바꾸기) 수를 확인합니다.
hints 횟수는 소스 문자열의 전체 길이에서 삭제 대체 횟수를 빼서 제공할 있습니다.
:param transcription: 대상 단어
:param reference: 소스 단어
:return: a tuple of #hits, #substitutions, #deletions, #insertions
"""
ref, hyp = [], []
ref.append(reference)
hyp.append(transcription)
#print("? : ", ref)
wer_s, wer_i, wer_d, wer_n = 0, 0, 0, 0
sen_err = 0
for n in range(len(ref)):
# update WER statistics
_, (s, i, d) = levenshtein(hyp[n].split(), ref[n].split())
wer_s += s
wer_i += i
wer_d += d
wer_n += len(ref[n].split())
# update SER statistics
if s + i + d > 0:
sen_err += 1
#print("reference : ",reference)
#print("reference cnt : ", reference.split())
#print("wer S : ", wer_s)
#print("wer I : ", wer_i)
#print("wer D : ", wer_d)
#print("wer_n : ", wer_n)
if wer_n > 0:
print('WER: %g%%, SER: %g%%' % (
(100.0 * (wer_s + wer_i + wer_d)) / wer_n,
(100.0 * sen_err) / len(ref)))
substitutions = wer_s
deletions = wer_d
insertions = wer_i
hits = len(reference.split()) - (substitutions + deletions) #correct words between refs and trans
return hits, substitutions, deletions, insertions
def _measure_er(
reference : str, transcription : str
) -> Tuple[int, int]:
"""
TBD
:param transcription: 대상 문자열로 변환할 소스 문자열
:param reference:
:return: a tuple of #
"""
TBD1 =""
TBD2 =""
return TBD1, TBD2
def get_cer(reference, transcription, rm_punctuation = True
) -> Tuple[int, int, int, int]:
# 문자 오류율(CER)은 자동 음성 인식 시스템의 성능에 대한 일반적인 메트릭입니다.
# CER은 WER(단어 오류율)과 유사하지만 단어 대신 문자에 대해 작동합니다.
# 이 코드에서는 문제는 사람들이 띄어쓰기를 지키지 않고 작성한 텍스트를 컴퓨터가 정확하게 인식하는 것이 매우 어렵기 때문에 인식에러에서 생략합니다.
# CER의 출력은 특히 삽입 수가 많은 경우 항상 0과 1 사이의 숫자가 아닙니다. 이 값은 종종 잘못 예측된 문자의 백분율과 연관됩니다. 값이 낮을수록 좋습니다.
# CER이 0인 ASR 시스템의 성능은 완벽한 점수입니다.
# CER = (S + D + I) / N = (S + D + I) / (S + D + C)
# S is the number of the substitutions,
# D is the number of the deletions,
# I is the number of the insertions,
# C is the number of the correct characters,
# N is the number of the characters in the reference (N=S+D+C).
refs = jiwer.RemoveWhiteSpace(replace_by_space=False)(reference)
trans = jiwer.RemoveWhiteSpace(replace_by_space=False)(transcription)
if rm_punctuation == True:
refs = jiwer.RemovePunctuation()(refs)
trans = jiwer.RemovePunctuation()(trans)
else:
refs = reference
trans = transcription
#print("refs : ", refs)
[hits ,cer_s, cer_d, cer_i] = _measure_cer(refs, trans)
substitutions = cer_s
deletions = cer_d
insertions = cer_i
#print("tmp hits : ", hits)
incorrect = substitutions + deletions + insertions
total = substitutions + deletions + hits + insertions
cer = incorrect / total
return cer, substitutions, deletions, insertions
def get_wer(reference, transcription, rm_punctuation = True
)-> Tuple[int, int, int, int]:
# WER = (S + D + I) / N = (S + D + I) / (S + D + C)
# S is the number of the substitutions,
# D is the number of the deletions,
# I is the number of the insertions,
# C is the number of the correct words,
# N is the number of the words in the reference (N=S+D+C).
if rm_punctuation == True:
refs = jiwer.RemovePunctuation()(reference)
trans = jiwer.RemovePunctuation()(transcription)
else:
refs = reference
trans = transcription
[hits, wer_s, wer_d, wer_i] = _measure_wer(refs, trans)
substitutions = wer_s
deletions = wer_d
insertions = wer_i
#print("tmp hits : ", hits)
incorrect = substitutions + deletions + insertions
total = substitutions + deletions + hits + insertions
wer = incorrect / total
return wer, substitutions, deletions, insertions