mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
from local
This commit is contained in:
parent
cb3b3da087
commit
61acac6c02
Binary file not shown.
236
egs/aishell/ASR/transformer_ctc/asr_metrics.py
Normal file
236
egs/aishell/ASR/transformer_ctc/asr_metrics.py
Normal file
@ -0,0 +1,236 @@
|
||||
from __future__ import unicode_literals
|
||||
import logging
|
||||
from typing import Any, Dict, List, Tuple, Union
|
||||
import sys
|
||||
import pandas as pd
|
||||
import jiwer
|
||||
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
def levenshtein(u, v):
|
||||
prev = None
|
||||
curr = [0] + list(range(1, len(v) + 1))
|
||||
# Operations: (SUB, DEL, INS)
|
||||
prev_ops = None
|
||||
curr_ops = [(0, 0, i) for i in range(len(v) + 1)]
|
||||
for x in range(1, len(u) + 1):
|
||||
prev, curr = curr, [x] + ([None] * len(v))
|
||||
prev_ops, curr_ops = curr_ops, [(0, x, 0)] + ([None] * len(v))
|
||||
for y in range(1, len(v) + 1):
|
||||
delcost = prev[y] + 1
|
||||
addcost = curr[y - 1] + 1
|
||||
subcost = prev[y - 1] + int(u[x - 1] != v[y - 1])
|
||||
curr[y] = min(subcost, delcost, addcost)
|
||||
if curr[y] == subcost:
|
||||
(n_s, n_d, n_i) = prev_ops[y - 1]
|
||||
curr_ops[y] = (n_s + int(u[x - 1] != v[y - 1]), n_d, n_i)
|
||||
elif curr[y] == delcost:
|
||||
(n_s, n_d, n_i) = prev_ops[y]
|
||||
curr_ops[y] = (n_s, n_d + 1, n_i)
|
||||
else:
|
||||
(n_s, n_d, n_i) = curr_ops[y - 1]
|
||||
curr_ops[y] = (n_s, n_d, n_i + 1)
|
||||
return curr[len(v)], curr_ops[len(v)]
|
||||
|
||||
|
||||
|
||||
def get_unicode_code(text):
|
||||
result = ''.join( char if ord(char) < 128 else '\\u'+format(ord(char), 'x') for char in text )
|
||||
return result
|
||||
|
||||
|
||||
def _measure_cer(
|
||||
reference : str, transcription : str
|
||||
) -> Tuple[int, int, int, int]:
|
||||
"""
|
||||
소스 단어를 대상 단아로 변환하는 데 필요한 편집 작업(삭제, 삽입, 바꾸기)의 수를 확인합니다.
|
||||
hints 횟수는 소스 딘아의 전체 길이에서 삭제 및 대체 횟수를 빼서 제공할 수 있습니다.
|
||||
|
||||
:param transcription: 대상 단어로 변환할 소스 문자열
|
||||
:param reference: 소스 단어
|
||||
:return: a tuple of #hits, #substitutions, #deletions, #insertions
|
||||
"""
|
||||
|
||||
ref, hyp = [], []
|
||||
|
||||
ref.append(reference)
|
||||
hyp.append(transcription)
|
||||
|
||||
#print("? : ", ref)
|
||||
|
||||
cer_s, cer_i, cer_d, cer_n = 0, 0, 0, 0
|
||||
sen_err = 0
|
||||
|
||||
for n in range(len(ref)):
|
||||
# update CER statistics
|
||||
_, (s, i, d) = levenshtein(hyp[n], ref[n])
|
||||
cer_s += s
|
||||
cer_i += i
|
||||
cer_d += d
|
||||
cer_n += len(ref[n])
|
||||
|
||||
# update SER statistics
|
||||
if s + i + d > 0:
|
||||
sen_err += 1
|
||||
|
||||
|
||||
|
||||
'''
|
||||
print("reference : ",reference)
|
||||
print("cer S : ", cer_s)
|
||||
print("cer I : ", cer_i)
|
||||
print("cer D : ", cer_d)
|
||||
print("cer_n : ", cer_n)
|
||||
|
||||
|
||||
if cer_n > 0:
|
||||
print('CER: %g%%, SER: %g%%' % (
|
||||
(100.0 * (cer_s + cer_i + cer_d)) / cer_n,
|
||||
(100.0 * sen_err) / len(ref)))
|
||||
'''
|
||||
substitutions = cer_s
|
||||
deletions = cer_d
|
||||
insertions = cer_i
|
||||
hits = len(reference) - (substitutions + deletions) #correct characters
|
||||
|
||||
return hits, substitutions, deletions, insertions
|
||||
|
||||
def _measure_wer(
|
||||
reference : str, transcription : str
|
||||
) -> Tuple[int, int, int, int]:
|
||||
"""
|
||||
소스 문자열을 대상 문자열로 변환하는 데 필요한 편집 작업(삭제, 삽입, 바꾸기)의 수를 확인합니다.
|
||||
hints 횟수는 소스 문자열의 전체 길이에서 삭제 및 대체 횟수를 빼서 제공할 수 있습니다.
|
||||
|
||||
:param transcription: 대상 단어
|
||||
:param reference: 소스 단어
|
||||
:return: a tuple of #hits, #substitutions, #deletions, #insertions
|
||||
"""
|
||||
|
||||
ref, hyp = [], []
|
||||
|
||||
ref.append(reference)
|
||||
hyp.append(transcription)
|
||||
|
||||
#print("? : ", ref)
|
||||
|
||||
wer_s, wer_i, wer_d, wer_n = 0, 0, 0, 0
|
||||
sen_err = 0
|
||||
|
||||
for n in range(len(ref)):
|
||||
# update WER statistics
|
||||
_, (s, i, d) = levenshtein(hyp[n].split(), ref[n].split())
|
||||
wer_s += s
|
||||
wer_i += i
|
||||
wer_d += d
|
||||
wer_n += len(ref[n].split())
|
||||
# update SER statistics
|
||||
if s + i + d > 0:
|
||||
sen_err += 1
|
||||
|
||||
|
||||
|
||||
#print("reference : ",reference)
|
||||
#print("reference cnt : ", reference.split())
|
||||
#print("wer S : ", wer_s)
|
||||
#print("wer I : ", wer_i)
|
||||
#print("wer D : ", wer_d)
|
||||
#print("wer_n : ", wer_n)
|
||||
|
||||
|
||||
if wer_n > 0:
|
||||
print('WER: %g%%, SER: %g%%' % (
|
||||
(100.0 * (wer_s + wer_i + wer_d)) / wer_n,
|
||||
(100.0 * sen_err) / len(ref)))
|
||||
|
||||
substitutions = wer_s
|
||||
deletions = wer_d
|
||||
insertions = wer_i
|
||||
hits = len(reference.split()) - (substitutions + deletions) #correct words between refs and trans
|
||||
|
||||
return hits, substitutions, deletions, insertions
|
||||
|
||||
|
||||
|
||||
|
||||
def _measure_er(
|
||||
reference : str, transcription : str
|
||||
) -> Tuple[int, int]:
|
||||
"""
|
||||
TBD
|
||||
:param transcription: 대상 문자열로 변환할 소스 문자열
|
||||
:param reference:
|
||||
:return: a tuple of #
|
||||
"""
|
||||
TBD1 =""
|
||||
TBD2 =""
|
||||
return TBD1, TBD2
|
||||
|
||||
|
||||
def get_cer(reference, transcription, rm_punctuation = True
|
||||
) -> Tuple[int, int, int, int]:
|
||||
|
||||
# 문자 오류율(CER)은 자동 음성 인식 시스템의 성능에 대한 일반적인 메트릭입니다.
|
||||
# CER은 WER(단어 오류율)과 유사하지만 단어 대신 문자에 대해 작동합니다.
|
||||
# 이 코드에서는 문제는 사람들이 띄어쓰기를 지키지 않고 작성한 텍스트를 컴퓨터가 정확하게 인식하는 것이 매우 어렵기 때문에 인식에러에서 생략합니다.
|
||||
# CER의 출력은 특히 삽입 수가 많은 경우 항상 0과 1 사이의 숫자가 아닙니다. 이 값은 종종 잘못 예측된 문자의 백분율과 연관됩니다. 값이 낮을수록 좋습니다.
|
||||
# CER이 0인 ASR 시스템의 성능은 완벽한 점수입니다.
|
||||
|
||||
# CER = (S + D + I) / N = (S + D + I) / (S + D + C)
|
||||
# S is the number of the substitutions,
|
||||
# D is the number of the deletions,
|
||||
# I is the number of the insertions,
|
||||
# C is the number of the correct characters,
|
||||
# N is the number of the characters in the reference (N=S+D+C).
|
||||
|
||||
refs = jiwer.RemoveWhiteSpace(replace_by_space=False)(reference)
|
||||
trans = jiwer.RemoveWhiteSpace(replace_by_space=False)(transcription)
|
||||
|
||||
if rm_punctuation == True:
|
||||
refs = jiwer.RemovePunctuation()(refs)
|
||||
trans = jiwer.RemovePunctuation()(trans)
|
||||
else:
|
||||
refs = reference
|
||||
trans = transcription
|
||||
|
||||
#print("refs : ", refs)
|
||||
|
||||
[hits ,cer_s, cer_d, cer_i] = _measure_cer(refs, trans)
|
||||
|
||||
substitutions = cer_s
|
||||
deletions = cer_d
|
||||
insertions = cer_i
|
||||
#print("tmp hits : ", hits)
|
||||
incorrect = substitutions + deletions + insertions
|
||||
total = substitutions + deletions + hits + insertions
|
||||
|
||||
cer = incorrect / total
|
||||
return cer, substitutions, deletions, insertions
|
||||
|
||||
|
||||
def get_wer(reference, transcription, rm_punctuation = True
|
||||
)-> Tuple[int, int, int, int]:
|
||||
|
||||
# WER = (S + D + I) / N = (S + D + I) / (S + D + C)
|
||||
# S is the number of the substitutions,
|
||||
# D is the number of the deletions,
|
||||
# I is the number of the insertions,
|
||||
# C is the number of the correct words,
|
||||
# N is the number of the words in the reference (N=S+D+C).
|
||||
if rm_punctuation == True:
|
||||
refs = jiwer.RemovePunctuation()(reference)
|
||||
trans = jiwer.RemovePunctuation()(transcription)
|
||||
else:
|
||||
refs = reference
|
||||
trans = transcription
|
||||
[hits, wer_s, wer_d, wer_i] = _measure_wer(refs, trans)
|
||||
|
||||
substitutions = wer_s
|
||||
deletions = wer_d
|
||||
insertions = wer_i
|
||||
#print("tmp hits : ", hits)
|
||||
incorrect = substitutions + deletions + insertions
|
||||
total = substitutions + deletions + hits + insertions
|
||||
|
||||
wer = incorrect / total
|
||||
return wer, substitutions, deletions, insertions
|
||||
Loading…
x
Reference in New Issue
Block a user