""" Calculate WER with Whisper-large-v3 or Paraformer models, following Seed-TTS https://github.com/BytedanceSpeech/seed-tts-eval """ import argparse import os import string import numpy as np import scipy import soundfile as sf import torch import zhconv from funasr import AutoModel from jiwer import compute_measures from tqdm import tqdm from transformers import WhisperForConditionalGeneration, WhisperProcessor from zhon.hanzi import punctuation def get_parser(): parser = argparse.ArgumentParser() parser.add_argument("--wav-path", type=str, help="path of the speech directory") parser.add_argument( "--decode-path", type=str, default=None, help="path of the output file of WER information", ) parser.add_argument( "--model-path", type=str, default=None, help="path of the local whisper and paraformer model, " "e.g., whisper: model/huggingface/whisper-large-v3/, " "paraformer: model/huggingface/paraformer-zh/", ) parser.add_argument( "--test-list", type=str, default="test.tsv", help="path of the transcript tsv file, where the first column " "is the wav name and the last column is the transcript", ) parser.add_argument("--lang", type=str, help="decoded language, zh or en") return parser def load_en_model(model_path): if model_path is None: model_path = "openai/whisper-large-v3" processor = WhisperProcessor.from_pretrained(model_path) model = WhisperForConditionalGeneration.from_pretrained(model_path) return processor, model def load_zh_model(model_path): if model_path is None: model_path = "paraformer-zh" model = AutoModel(model=model_path) return model def process_one(hypo, truth, lang): punctuation_all = punctuation + string.punctuation for x in punctuation_all: if x == "'": continue truth = truth.replace(x, "") hypo = hypo.replace(x, "") truth = truth.replace(" ", " ") hypo = hypo.replace(" ", " ") if lang == "zh": truth = " ".join([x for x in truth]) hypo = " ".join([x for x in hypo]) elif lang == "en": truth = truth.lower() hypo = hypo.lower() else: raise NotImplementedError measures = compute_measures(truth, hypo) word_num = len(truth.split(" ")) wer = measures["wer"] subs = measures["substitutions"] dele = measures["deletions"] inse = measures["insertions"] return (truth, hypo, wer, subs, dele, inse, word_num) def main(test_list, wav_path, model_path, decode_path, lang, device): if lang == "en": processor, model = load_en_model(model_path) model.to(device) elif lang == "zh": model = load_zh_model(model_path) params = [] for line in open(test_list).readlines(): line = line.strip() items = line.split("\t") wav_name, text_ref = items[0], items[-1] file_path = os.path.join(wav_path, wav_name + ".wav") assert os.path.exists(file_path), f"{file_path}" params.append((file_path, text_ref)) wers = [] inses = [] deles = [] subses = [] word_nums = 0 if decode_path: decode_dir = os.path.dirname(decode_path) if not os.path.exists(decode_dir): os.makedirs(decode_dir) fout = open(decode_path, "w") for wav_path, text_ref in tqdm(params): if lang == "en": wav, sr = sf.read(wav_path) if sr != 16000: wav = scipy.signal.resample(wav, int(len(wav) * 16000 / sr)) input_features = processor( wav, sampling_rate=16000, return_tensors="pt" ).input_features input_features = input_features.to(device) forced_decoder_ids = processor.get_decoder_prompt_ids( language="english", task="transcribe" ) predicted_ids = model.generate( input_features, forced_decoder_ids=forced_decoder_ids ) transcription = processor.batch_decode( predicted_ids, skip_special_tokens=True )[0] elif lang == "zh": res = model.generate(input=wav_path, batch_size_s=300, disable_pbar=True) transcription = res[0]["text"] transcription = zhconv.convert(transcription, "zh-cn") truth, hypo, wer, subs, dele, inse, word_num = process_one( transcription, text_ref, lang ) if decode_path: fout.write(f"{wav_path}\t{wer}\t{truth}\t{hypo}\t{inse}\t{dele}\t{subs}\n") wers.append(float(wer)) inses.append(float(inse)) deles.append(float(dele)) subses.append(float(subs)) word_nums += word_num wer_avg = round(np.mean(wers) * 100, 3) wer = round((np.sum(subses) + np.sum(deles) + np.sum(inses)) / word_nums * 100, 3) subs = round(np.mean(subses) * 100, 3) dele = round(np.mean(deles) * 100, 3) inse = round(np.mean(inses) * 100, 3) print(f"Seed-TTS WER: {wer_avg}%\n") print(f"WER: {wer}%\n") if decode_path: fout.write(f"SeedTTS WER: {wer_avg}%\n") fout.write(f"WER: {wer}%\n") fout.flush() if __name__ == "__main__": parser = get_parser() args = parser.parse_args() if torch.cuda.is_available(): device = torch.device("cuda", 0) else: device = torch.device("cpu") main( args.test_list, args.wav_path, args.model_path, args.decode_path, args.lang, device, )