"""
Calculate WER with Whisper-large-v3 or Paraformer models,
following Seed-TTS https://github.com/BytedanceSpeech/seed-tts-eval
"""
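# Example invocation (a sketch; the script name and paths below are placeholders):
#
#   python3 compute_wer.py \
#     --wav-path path/to/generated_wavs \
#     --test-list test.tsv \
#     --lang en \
#     --model-path models/huggingface/whisper-large-v3 \
#     --decode-path results/wer_details.txt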
import argparse
import os
import string

import numpy as np
import scipy.signal
import soundfile as sf
import torch
import zhconv
from funasr import AutoModel
from jiwer import compute_measures
from tqdm import tqdm
from transformers import WhisperForConditionalGeneration, WhisperProcessor
from zhon.hanzi import punctuation


def get_parser():
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--wav-path",
        type=str,
        help="path to the directory containing the wav files to evaluate",
    )
    parser.add_argument(
        "--decode-path",
        type=str,
        default=None,
        help="path of the output file with per-utterance WER details",
    )
    parser.add_argument(
        "--model-path",
        type=str,
        default=None,
        help="path of a local Whisper or Paraformer model, "
        "e.g., whisper: model/huggingface/whisper-large-v3/, "
        "paraformer: model/huggingface/paraformer-zh/",
    )
    parser.add_argument(
        "--test-list",
        type=str,
        default="test.tsv",
        help="path of the transcript tsv file; the first column "
        "is the wav name and the last column is the transcript",
    )
    parser.add_argument("--lang", type=str, help="language to decode, zh or en")
    return parser
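# The --test-list file is a plain TSV: the first column is the wav name (without the
# ".wav" suffix, since ".wav" is appended when locating the file) and the last column
# is the reference transcript, e.g. (illustrative rows, tab-separated):
#
#   utt_0001<TAB>printing in the only sense with which we are at present concerned
#   utt_0002<TAB>the quick brown fox jumps over the lazy dog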
def load_en_model(model_path):
    if model_path is None:
        model_path = "openai/whisper-large-v3"
    processor = WhisperProcessor.from_pretrained(model_path)
    model = WhisperForConditionalGeneration.from_pretrained(model_path)
    return processor, model


def load_zh_model(model_path):
    if model_path is None:
        model_path = "paraformer-zh"
    model = AutoModel(model=model_path)
    return model
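# Note: when --model-path is not given, the identifiers above ("openai/whisper-large-v3"
# and "paraformer-zh") are passed straight to transformers / funasr, which will typically
# try to fetch the models from their respective hubs on first use, so either provide a
# local path or make sure the machine can download them.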
def process_one(hypo, truth, lang):
    # Strip punctuation (both Chinese and ASCII) from reference and hypothesis,
    # keeping apostrophes.
    punctuation_all = punctuation + string.punctuation
    for x in punctuation_all:
        if x == "'":
            continue
        truth = truth.replace(x, "")
        hypo = hypo.replace(x, "")

    truth = truth.replace("  ", " ")
    hypo = hypo.replace("  ", " ")

    if lang == "zh":
        # Score Chinese at the character level: separate every character with a space.
        truth = " ".join([x for x in truth])
        hypo = " ".join([x for x in hypo])
    elif lang == "en":
        truth = truth.lower()
        hypo = hypo.lower()
    else:
        raise NotImplementedError

    measures = compute_measures(truth, hypo)
    word_num = len(truth.split(" "))
    wer = measures["wer"]
    subs = measures["substitutions"]
    dele = measures["deletions"]
    inse = measures["insertions"]
    return (truth, hypo, wer, subs, dele, inse, word_num)
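# For example, process_one(hypo="hello world", truth="hello there world", lang="en")
# aligns the two strings with one deletion against three reference words, so the
# returned wer is 1/3, with subs == 0, dele == 1, inse == 0 and word_num == 3.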
def main(test_list, wav_path, model_path, decode_path, lang, device):
    if lang == "en":
        processor, model = load_en_model(model_path)
        model.to(device)
    elif lang == "zh":
        model = load_zh_model(model_path)

    params = []
    for line in open(test_list).readlines():
        line = line.strip()
        items = line.split("\t")
        wav_name, text_ref = items[0], items[-1]
        file_path = os.path.join(wav_path, wav_name + ".wav")
        assert os.path.exists(file_path), f"{file_path}"

        params.append((file_path, text_ref))

    wers = []
    inses = []
    deles = []
    subses = []
    word_nums = 0

    if decode_path:
        decode_dir = os.path.dirname(decode_path)
        if decode_dir and not os.path.exists(decode_dir):
            os.makedirs(decode_dir)
        fout = open(decode_path, "w")

    for wav_path, text_ref in tqdm(params):
        if lang == "en":
            wav, sr = sf.read(wav_path)
            if sr != 16000:
                # Whisper expects 16 kHz input; resample if necessary.
                wav = scipy.signal.resample(wav, int(len(wav) * 16000 / sr))
            input_features = processor(
                wav, sampling_rate=16000, return_tensors="pt"
            ).input_features
            input_features = input_features.to(device)
            forced_decoder_ids = processor.get_decoder_prompt_ids(
                language="english", task="transcribe"
            )
            predicted_ids = model.generate(
                input_features, forced_decoder_ids=forced_decoder_ids
            )
            transcription = processor.batch_decode(
                predicted_ids, skip_special_tokens=True
            )[0]
        elif lang == "zh":
            res = model.generate(input=wav_path, batch_size_s=300, disable_pbar=True)
            transcription = res[0]["text"]
            # Normalize traditional characters to simplified Chinese.
            transcription = zhconv.convert(transcription, "zh-cn")

        truth, hypo, wer, subs, dele, inse, word_num = process_one(
            transcription, text_ref, lang
        )
        if decode_path:
            fout.write(f"{wav_path}\t{wer}\t{truth}\t{hypo}\t{inse}\t{dele}\t{subs}\n")
        wers.append(float(wer))
        inses.append(float(inse))
        deles.append(float(dele))
        subses.append(float(subs))
        word_nums += word_num
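    # Two aggregate numbers are printed below:
    #   * "Seed-TTS WER": the unweighted mean of the per-utterance WERs
    #   * "WER": the corpus-level WER, i.e. total substitutions + deletions +
    #     insertions divided by the total number of reference words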
    wer_avg = round(np.mean(wers) * 100, 3)
    wer = round((np.sum(subses) + np.sum(deles) + np.sum(inses)) / word_nums * 100, 3)
    subs = round(np.mean(subses) * 100, 3)
    dele = round(np.mean(deles) * 100, 3)
    inse = round(np.mean(inses) * 100, 3)
    print(f"Seed-TTS WER: {wer_avg}%\n")
    print(f"WER: {wer}%\n")
    if decode_path:
        fout.write(f"Seed-TTS WER: {wer_avg}%\n")
        fout.write(f"WER: {wer}%\n")
        fout.flush()
if __name__ == "__main__":
    parser = get_parser()
    args = parser.parse_args()
    if torch.cuda.is_available():
        device = torch.device("cuda", 0)
    else:
        device = torch.device("cpu")
    main(
        args.test_list,
        args.wav_path,
        args.model_path,
        args.decode_path,
        args.lang,
        device,
    )