From 73a1d39c86d959597126ca12d8ee99ee1ee545c3 Mon Sep 17 00:00:00 2001 From: dohe0342 Date: Thu, 2 Feb 2023 19:16:31 +0900 Subject: [PATCH] from local --- .../ASR/transformer_ctc/.decode.py.swp | Bin 53248 -> 53248 bytes .../ASR/conformer_ctc2/asr_metrics.py | 236 ++++++++++++++++++ 2 files changed, 236 insertions(+) create mode 100644 egs/librispeech/ASR/conformer_ctc2/asr_metrics.py diff --git a/egs/aishell/ASR/transformer_ctc/.decode.py.swp b/egs/aishell/ASR/transformer_ctc/.decode.py.swp index ca0a287f872f27886d9a08118d2b79e6ad58407b..e80ac848bd3ab0bccf7cc48f9c4f2e225c9cea7b 100644 GIT binary patch delta 32 mcmZozz}&EaSv1KY%+puFQqO<^2m}}y%6o1n7i|=MeI5XhR|(Vr delta 32 mcmZozz}&EaSv1KY%+puFQqO<^2m}}yta@%I8*LPQeI5Xd@(EV} diff --git a/egs/librispeech/ASR/conformer_ctc2/asr_metrics.py b/egs/librispeech/ASR/conformer_ctc2/asr_metrics.py new file mode 100644 index 000000000..e197741c1 --- /dev/null +++ b/egs/librispeech/ASR/conformer_ctc2/asr_metrics.py @@ -0,0 +1,236 @@ +from __future__ import unicode_literals +import logging +from typing import Any, Dict, List, Tuple, Union +import sys +import pandas as pd +import jiwer + +# -*- coding: utf-8 -*- + +def levenshtein(u, v): + prev = None + curr = [0] + list(range(1, len(v) + 1)) + # Operations: (SUB, DEL, INS) + prev_ops = None + curr_ops = [(0, 0, i) for i in range(len(v) + 1)] + for x in range(1, len(u) + 1): + prev, curr = curr, [x] + ([None] * len(v)) + prev_ops, curr_ops = curr_ops, [(0, x, 0)] + ([None] * len(v)) + for y in range(1, len(v) + 1): + delcost = prev[y] + 1 + addcost = curr[y - 1] + 1 + subcost = prev[y - 1] + int(u[x - 1] != v[y - 1]) + curr[y] = min(subcost, delcost, addcost) + if curr[y] == subcost: + (n_s, n_d, n_i) = prev_ops[y - 1] + curr_ops[y] = (n_s + int(u[x - 1] != v[y - 1]), n_d, n_i) + elif curr[y] == delcost: + (n_s, n_d, n_i) = prev_ops[y] + curr_ops[y] = (n_s, n_d + 1, n_i) + else: + (n_s, n_d, n_i) = curr_ops[y - 1] + curr_ops[y] = (n_s, n_d, n_i + 1) + return curr[len(v)], curr_ops[len(v)] + + + +def get_unicode_code(text): + result = ''.join( char if ord(char) < 128 else '\\u'+format(ord(char), 'x') for char in text ) + return result + + +def _measure_cer( + reference : str, transcription : str +) -> Tuple[int, int, int, int]: + """ + 소스 단어를 대상 단아로 변환하는 데 필요한 편집 작업(삭제, 삽입, 바꾸기)의 수를 확인합니다. + hints 횟수는 소스 딘아의 전체 길이에서 삭제 및 대체 횟수를 빼서 제공할 수 있습니다. + + :param transcription: 대상 단어로 변환할 소스 문자열 + :param reference: 소스 단어 + :return: a tuple of #hits, #substitutions, #deletions, #insertions + """ + + ref, hyp = [], [] + + ref.append(reference) + hyp.append(transcription) + + #print("? : ", ref) + + cer_s, cer_i, cer_d, cer_n = 0, 0, 0, 0 + sen_err = 0 + + for n in range(len(ref)): + # update CER statistics + _, (s, i, d) = levenshtein(hyp[n], ref[n]) + cer_s += s + cer_i += i + cer_d += d + cer_n += len(ref[n]) + + # update SER statistics + if s + i + d > 0: + sen_err += 1 + + + + ''' + print("reference : ",reference) + print("cer S : ", cer_s) + print("cer I : ", cer_i) + print("cer D : ", cer_d) + print("cer_n : ", cer_n) + + + if cer_n > 0: + print('CER: %g%%, SER: %g%%' % ( + (100.0 * (cer_s + cer_i + cer_d)) / cer_n, + (100.0 * sen_err) / len(ref))) + ''' + substitutions = cer_s + deletions = cer_d + insertions = cer_i + hits = len(reference) - (substitutions + deletions) #correct characters + + return hits, substitutions, deletions, insertions + +def _measure_wer( + reference : str, transcription : str +) -> Tuple[int, int, int, int]: + """ + 소스 문자열을 대상 문자열로 변환하는 데 필요한 편집 작업(삭제, 삽입, 바꾸기)의 수를 확인합니다. + hints 횟수는 소스 문자열의 전체 길이에서 삭제 및 대체 횟수를 빼서 제공할 수 있습니다. + + :param transcription: 대상 단어 + :param reference: 소스 단어 + :return: a tuple of #hits, #substitutions, #deletions, #insertions + """ + + ref, hyp = [], [] + + ref.append(reference) + hyp.append(transcription) + + #print("? : ", ref) + + wer_s, wer_i, wer_d, wer_n = 0, 0, 0, 0 + sen_err = 0 + + for n in range(len(ref)): + # update WER statistics + _, (s, i, d) = levenshtein(hyp[n].split(), ref[n].split()) + wer_s += s + wer_i += i + wer_d += d + wer_n += len(ref[n].split()) + # update SER statistics + if s + i + d > 0: + sen_err += 1 + + + + #print("reference : ",reference) + #print("reference cnt : ", reference.split()) + #print("wer S : ", wer_s) + #print("wer I : ", wer_i) + #print("wer D : ", wer_d) + #print("wer_n : ", wer_n) + + + if wer_n > 0: + print('WER: %g%%, SER: %g%%' % ( + (100.0 * (wer_s + wer_i + wer_d)) / wer_n, + (100.0 * sen_err) / len(ref))) + + substitutions = wer_s + deletions = wer_d + insertions = wer_i + hits = len(reference.split()) - (substitutions + deletions) #correct words between refs and trans + + return hits, substitutions, deletions, insertions + + + + +def _measure_er( + reference : str, transcription : str +) -> Tuple[int, int]: + """ + TBD + :param transcription: 대상 문자열로 변환할 소스 문자열 + :param reference: + :return: a tuple of # + """ + TBD1 ="" + TBD2 ="" + return TBD1, TBD2 + + +def get_cer(reference, transcription, rm_punctuation = True + ) -> Tuple[int, int, int, int]: + + # 문자 오류율(CER)은 자동 음성 인식 시스템의 성능에 대한 일반적인 메트릭입니다. + # CER은 WER(단어 오류율)과 유사하지만 단어 대신 문자에 대해 작동합니다. + # 이 코드에서는 문제는 사람들이 띄어쓰기를 지키지 않고 작성한 텍스트를 컴퓨터가 정확하게 인식하는 것이 매우 어렵기 때문에 인식에러에서 생략합니다. + # CER의 출력은 특히 삽입 수가 많은 경우 항상 0과 1 사이의 숫자가 아닙니다. 이 값은 종종 잘못 예측된 문자의 백분율과 연관됩니다. 값이 낮을수록 좋습니다. + # CER이 0인 ASR 시스템의 성능은 완벽한 점수입니다. + + # CER = (S + D + I) / N = (S + D + I) / (S + D + C) + # S is the number of the substitutions, + # D is the number of the deletions, + # I is the number of the insertions, + # C is the number of the correct characters, + # N is the number of the characters in the reference (N=S+D+C). + + refs = jiwer.RemoveWhiteSpace(replace_by_space=False)(reference) + trans = jiwer.RemoveWhiteSpace(replace_by_space=False)(transcription) + + if rm_punctuation == True: + refs = jiwer.RemovePunctuation()(refs) + trans = jiwer.RemovePunctuation()(trans) + else: + refs = reference + trans = transcription + + #print("refs : ", refs) + + [hits ,cer_s, cer_d, cer_i] = _measure_cer(refs, trans) + + substitutions = cer_s + deletions = cer_d + insertions = cer_i + #print("tmp hits : ", hits) + incorrect = substitutions + deletions + insertions + total = substitutions + deletions + hits + insertions + + cer = incorrect / total + return cer, substitutions, deletions, insertions + + +def get_wer(reference, transcription, rm_punctuation = True + )-> Tuple[int, int, int, int]: + + # WER = (S + D + I) / N = (S + D + I) / (S + D + C) + # S is the number of the substitutions, + # D is the number of the deletions, + # I is the number of the insertions, + # C is the number of the correct words, + # N is the number of the words in the reference (N=S+D+C). + if rm_punctuation == True: + refs = jiwer.RemovePunctuation()(reference) + trans = jiwer.RemovePunctuation()(transcription) + else: + refs = reference + trans = transcription + [hits, wer_s, wer_d, wer_i] = _measure_wer(refs, trans) + + substitutions = wer_s + deletions = wer_d + insertions = wer_i + #print("tmp hits : ", hits) + incorrect = substitutions + deletions + insertions + total = substitutions + deletions + hits + insertions + + wer = incorrect / total + return wer, substitutions, deletions, insertions