icefall/egs/zipvoice/local/evaluate_wer_seedtts.py
"""
Calculate WER with Whisper-large-v3 or Paraformer models,
following Seed-TTS https://github.com/BytedanceSpeech/seed-tts-eval
"""
import argparse
import os
import string

import numpy as np
import scipy.signal
import soundfile as sf
import torch
import zhconv
from funasr import AutoModel
from jiwer import compute_measures
from tqdm import tqdm
from transformers import WhisperForConditionalGeneration, WhisperProcessor
from zhon.hanzi import punctuation


def get_parser():
    parser = argparse.ArgumentParser()

    parser.add_argument("--wav-path", type=str, help="path of the speech directory")
    parser.add_argument(
        "--decode-path",
        type=str,
        default=None,
        help="path of the output file of WER information",
    )
    parser.add_argument(
        "--model-path",
        type=str,
        default=None,
        help="path of the local whisper and paraformer model, "
        "e.g., whisper: model/huggingface/whisper-large-v3/, "
        "paraformer: model/huggingface/paraformer-zh/",
    )
    parser.add_argument(
        "--test-list",
        type=str,
        default="test.tsv",
        help="path of the transcript tsv file, where the first column "
        "is the wav name and the last column is the transcript",
    )
    parser.add_argument("--lang", type=str, help="decoded language, zh or en")

    return parser
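
# The --test-list file is a tab-separated table whose first column is the wav
# name (without the ".wav" suffix) and whose last column is the reference
# transcript; any columns in between are ignored. An illustrative line (the
# utterance name and text are made-up examples):
#
#   utt_0001<TAB>some reference transcript for this utterance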


def load_en_model(model_path):
    # Fall back to the Hugging Face hub checkpoint when no local path is given.
    if model_path is None:
        model_path = "openai/whisper-large-v3"
    processor = WhisperProcessor.from_pretrained(model_path)
    model = WhisperForConditionalGeneration.from_pretrained(model_path)
    return processor, model


def load_zh_model(model_path):
    # Fall back to the model name resolved by funasr when no local path is given.
    if model_path is None:
        model_path = "paraformer-zh"
    model = AutoModel(model=model_path)
    return model


def process_one(hypo, truth, lang):
    # Remove both Chinese and English punctuation (apostrophes are kept).
    punctuation_all = punctuation + string.punctuation
    for x in punctuation_all:
        if x == "'":
            continue
        truth = truth.replace(x, "")
        hypo = hypo.replace(x, "")

    # Collapse double spaces left behind by the punctuation removal.
    truth = truth.replace("  ", " ")
    hypo = hypo.replace("  ", " ")

    if lang == "zh":
        # Score Chinese at the character level.
        truth = " ".join([x for x in truth])
        hypo = " ".join([x for x in hypo])
    elif lang == "en":
        truth = truth.lower()
        hypo = hypo.lower()
    else:
        raise NotImplementedError

    measures = compute_measures(truth, hypo)
    word_num = len(truth.split(" "))
    wer = measures["wer"]
    subs = measures["substitutions"]
    dele = measures["deletions"]
    inse = measures["insertions"]
    return (truth, hypo, wer, subs, dele, inse, word_num)
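
# For example (illustrative strings, not recipe data): with lang="zh", the pair
# ("你好，世界", "你好世界") is normalized to "你 好 世 界" on both sides, so
# punctuation differences do not count as errors; with lang="en",
# "Hello, World!" becomes "hello world".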


def main(test_list, wav_path, model_path, decode_path, lang, device):
    if lang == "en":
        processor, model = load_en_model(model_path)
        model.to(device)
    elif lang == "zh":
        model = load_zh_model(model_path)

    # Collect (wav file, reference transcript) pairs from the test list.
    params = []
    for line in open(test_list).readlines():
        line = line.strip()
        items = line.split("\t")
        wav_name, text_ref = items[0], items[-1]
        file_path = os.path.join(wav_path, wav_name + ".wav")
        assert os.path.exists(file_path), f"{file_path}"
        params.append((file_path, text_ref))

    wers = []
    inses = []
    deles = []
    subses = []
    word_nums = 0

    if decode_path:
        decode_dir = os.path.dirname(decode_path)
        if decode_dir and not os.path.exists(decode_dir):
            os.makedirs(decode_dir)
        fout = open(decode_path, "w")

    for wav_file, text_ref in tqdm(params):
        if lang == "en":
            wav, sr = sf.read(wav_file)
            if sr != 16000:
                # Whisper expects 16 kHz input.
                wav = scipy.signal.resample(wav, int(len(wav) * 16000 / sr))
            input_features = processor(
                wav, sampling_rate=16000, return_tensors="pt"
            ).input_features
            input_features = input_features.to(device)
            forced_decoder_ids = processor.get_decoder_prompt_ids(
                language="english", task="transcribe"
            )
            predicted_ids = model.generate(
                input_features, forced_decoder_ids=forced_decoder_ids
            )
            transcription = processor.batch_decode(
                predicted_ids, skip_special_tokens=True
            )[0]
        elif lang == "zh":
            res = model.generate(input=wav_file, batch_size_s=300, disable_pbar=True)
            transcription = res[0]["text"]
            # Convert traditional characters to simplified before scoring.
            transcription = zhconv.convert(transcription, "zh-cn")

        truth, hypo, wer, subs, dele, inse, word_num = process_one(
            transcription, text_ref, lang
        )
        if decode_path:
            # One tab-separated line per utterance: wav file, WER, reference,
            # hypothesis, insertions, deletions, substitutions.
            fout.write(f"{wav_file}\t{wer}\t{truth}\t{hypo}\t{inse}\t{dele}\t{subs}\n")
        wers.append(float(wer))
        inses.append(float(inse))
        deles.append(float(dele))
        subses.append(float(subs))
        word_nums += word_num

    # Seed-TTS WER: the average of the per-utterance WERs.
    wer_avg = round(np.mean(wers) * 100, 3)
    # WER: corpus-level WER, i.e. total errors divided by total reference words.
    wer = round((np.sum(subses) + np.sum(deles) + np.sum(inses)) / word_nums * 100, 3)
    subs = round(np.mean(subses) * 100, 3)
    dele = round(np.mean(deles) * 100, 3)
    inse = round(np.mean(inses) * 100, 3)

    print(f"Seed-TTS WER: {wer_avg}%\n")
    print(f"WER: {wer}%\n")

    if decode_path:
        fout.write(f"SeedTTS WER: {wer_avg}%\n")
        fout.write(f"WER: {wer}%\n")
        fout.flush()


if __name__ == "__main__":
    parser = get_parser()
    args = parser.parse_args()

    if torch.cuda.is_available():
        device = torch.device("cuda", 0)
    else:
        device = torch.device("cpu")

    main(
        args.test_list,
        args.wav_path,
        args.model_path,
        args.decode_path,
        args.lang,
        device,
    )