"""
|
||
Calculate WER with Hubert models.
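
Usage (illustrative; the script name and the wav/output paths below are
placeholders, not fixed by this recipe):
    python compute_wer_hubert.py \
        --wav-path /path/to/wavs \
        --test-list test.tsv \
        --model-path model/huggingface/hubert-large-ls960-ft \
        --decode-path exp/hubert_decode/wer_details.txt

Requires librosa, soundfile, jiwer, tqdm, and transformers in addition to
torch.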
"""

import argparse
import os
import re
from pathlib import Path

import librosa
import numpy as np
import soundfile as sf
import torch
from jiwer import compute_measures
from tqdm import tqdm
from transformers import pipeline


def get_parser():
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--wav-path", type=str, help="Path of the directory containing the wav files."
    )
    parser.add_argument(
        "--decode-path",
        type=str,
        default=None,
        help="Path of the output file for per-utterance WER details.",
    )
    parser.add_argument(
        "--model-path",
        type=str,
        default=None,
        help="Path of a local HuBERT model, e.g., model/huggingface/hubert-large-ls960-ft.",
    )
    parser.add_argument(
        "--test-list",
        type=str,
        default="test.tsv",
        help="Path of the transcript tsv file, where the first column "
        "is the wav name (without extension) and the last column is the transcript.",
    )
    parser.add_argument(
        "--batch-size", type=int, default=16, help="Decoding batch size."
    )
    return parser


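# post_process normalizes both references and hypotheses before scoring:
# curly apostrophes become straight ones, the text is lowercased, anything
# other than letters, digits, and apostrophes is dropped, and whitespace is
# collapsed. Example (illustrative):
#     post_process("It’s GREAT, really!")  ->  "it's great really"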
def post_process(text: str):
    text = text.replace("‘", "'")
    text = text.replace("’", "'")
    text = re.sub(r"[^a-zA-Z0-9']", " ", text.lower())
    text = re.sub(r"\s+", " ", text)
    text = text.strip()
    return text


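# process_one scores a single utterance with jiwer.compute_measures(truth,
# hypo), which returns a dict of alignment statistics; only "wer" and the raw
# "substitutions"/"deletions"/"insertions" counts are kept here. Example
# (illustrative):
#     process_one("the cat sat", "the cat sat down")
#     ->  ("the cat sat down", "the cat sat", 0.25, 0, 1, 0, 4)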
def process_one(hypo, truth):
    truth = post_process(truth)
    hypo = post_process(hypo)

    measures = compute_measures(truth, hypo)
    word_num = len(truth.split(" "))
    wer = measures["wer"]
    subs = measures["substitutions"]
    dele = measures["deletions"]
    inse = measures["insertions"]
    return (truth, hypo, wer, subs, dele, inse, word_num)


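# Each dataset item is a dict the transformers ASR pipeline understands:
# "array" (the waveform, resampled to 16 kHz; mono input is assumed) and
# "sampling_rate". The extra keys "reference" and "wav_name" are carried
# through to the pipeline output, where main() reads them back as
# out["reference"][0] / out["wav_name"][0]. The test list is tab-separated,
# e.g. (illustrative):
#     utt_0001<TAB>speaker_a<TAB>hello world
# where column 0 names utt_0001.wav and the last column is the transcript.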
class SpeechEvalDataset(torch.utils.data.Dataset):
    def __init__(self, wav_path: str, test_list: str):
        super().__init__()
        self.wav_name = []
        self.wav_paths = []
        self.transcripts = []
        with Path(test_list).open("r", encoding="utf8") as f:
            meta = [item.split("\t") for item in f.read().rstrip().split("\n")]
        for item in meta:
            self.wav_name.append(item[0])
            self.wav_paths.append(Path(wav_path, item[0] + ".wav"))
            self.transcripts.append(item[-1])

    def __len__(self):
        return len(self.wav_paths)

    def __getitem__(self, index: int):
        wav, sampling_rate = sf.read(self.wav_paths[index])
        item = {
            "array": librosa.resample(wav, orig_sr=sampling_rate, target_sr=16000),
            "sampling_rate": 16000,
            "reference": self.transcripts[index],
            "wav_name": self.wav_name[index],
        }
        return item


def main(test_list, wav_path, model_path, decode_path, batch_size, device):

    if model_path is not None:
        pipe = pipeline(
            "automatic-speech-recognition",
            model=model_path,
            device=device,
            tokenizer=model_path,
        )
    else:
        pipe = pipeline(
            "automatic-speech-recognition",
            model="facebook/hubert-large-ls960-ft",
            device=device,
        )

    dataset = SpeechEvalDataset(wav_path, test_list)

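    # Feeding the dataset to the pipeline streams batched inference; tqdm just
    # wraps the iterator to show progress. generate_kwargs is only consumed by
    # seq2seq checkpoints such as Whisper; a CTC model such as HuBERT decodes
    # without generation and does not use these hints.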
    bar = tqdm(
        pipe(
            dataset,
            generate_kwargs={"language": "english", "task": "transcribe"},
            batch_size=batch_size,
        ),
        total=len(dataset),
    )

    wers = []
    inses = []
    deles = []
    subses = []
    word_nums = 0
    if decode_path:
        decode_dir = os.path.dirname(decode_path)
        # decode_dir is empty when decode_path is a bare filename; only create
        # it when there is a directory component.
        if decode_dir and not os.path.exists(decode_dir):
            os.makedirs(decode_dir)
        fout = open(decode_path, "w")
    for out in bar:
        wav_name = out["wav_name"][0]
        transcription = post_process(out["text"].strip())
        text_ref = post_process(out["reference"][0].strip())
        truth, hypo, wer, subs, dele, inse, word_num = process_one(
            transcription, text_ref
        )
        if decode_path:
            fout.write(f"{wav_name}\t{wer}\t{truth}\t{hypo}\t{inse}\t{dele}\t{subs}\n")
        wers.append(float(wer))
        inses.append(float(inse))
        deles.append(float(dele))
        subses.append(float(subs))
        word_nums += word_num

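    # Pool the error counts over the whole corpus:
    #     WER = (S + D + I) / N
    # where N is the total number of reference words. This weights utterances
    # by length, unlike averaging the per-utterance WERs collected in `wers`.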
    wer = round((np.sum(subses) + np.sum(deles) + np.sum(inses)) / word_nums * 100, 3)
    subs = round(np.mean(subses) * 100, 3)
    dele = round(np.mean(deles) * 100, 3)
    inse = round(np.mean(inses) * 100, 3)
    print(f"WER: {wer}%\n")
    if decode_path:
        fout.write(f"WER: {wer}%\n")
        fout.flush()
        fout.close()


if __name__ == "__main__":
    parser = get_parser()
    args = parser.parse_args()
    if torch.cuda.is_available():
        device = torch.device("cuda", 0)
    else:
        device = torch.device("cpu")
    main(
        args.test_list,
        args.wav_path,
        args.model_path,
        args.decode_path,
        args.batch_size,
        device,
    )