#!/usr/bin/env python3
#
# Copyright 2024 Author: Yuekai Zhang
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
|
|
This file uses speech io offcial pipline to normalize the decoding results.
|
|
https://github.com/SpeechColab/Leaderboard/blob/master/utils/textnorm_zh.py
|
|
|
|
Usage:
|
|
python normalize_results.py --model-log-dir ./whisper_decoding_log_dir --output-log-dir ./results_norm
|
|
"""

import argparse
from pathlib import Path
from typing import Dict, List, Tuple

from speechio_norm import TextNorm

from icefall.utils import store_transcripts, write_error_stats


def get_parser():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    parser.add_argument(
        "--model-log-dir",
        type=str,
        default="./recogs_whisper",
        help="The directory containing the whisper decoding logs, "
        "e.g. recogs-SPEECHIO_ASR_ZH00014-beam-search-epoch--1-avg-1.txt",
    )
    parser.add_argument(
        "--output-log-dir",
        type=str,
        default="./results_whisper_norm",
        help="The directory to store the normalized whisper logs",
    )
    return parser


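# For each test set this writes, under the output directory (main() passes a
# single key, "norm", and the epoch/avg suffix is fixed below):
#   recogs-{test_set}-norm-epoch-999-avg-1.txt       normalized transcripts
#   errs-{test_set}-norm-epoch-999-avg-1.txt         detailed error stats
#   wer-summary-{test_set}-norm-epoch-999-avg-1.txt  WER summary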
def save_results_with_speechio_text_norm(
    res_dir: Path,
    test_set_name: str,
    results_dict: Dict[str, List[Tuple[str, str, str]]],
):
    normalizer = TextNorm()
    # Normalize both refs and hyps in results_dict.
    for key, results in results_dict.items():
        results_norm = []
        for item in results:
            wav_name, ref, hyp = item
            ref = normalizer(ref)
            hyp = normalizer(hyp)
            results_norm.append((wav_name, ref, hyp))
        results_dict[key] = results_norm

    test_set_wers = dict()

    suffix = "epoch-999-avg-1"

    for key, results in results_dict.items():
        recog_path = res_dir / f"recogs-{test_set_name}-{key}-{suffix}.txt"
        results = sorted(results)
        store_transcripts(filename=recog_path, texts=results)
        print(f"The transcripts are stored in {recog_path}")

        # The following prints out WERs, per-word error statistics and aligned
        # ref/hyp pairs.
        errs_filename = res_dir / f"errs-{test_set_name}-{key}-{suffix}.txt"
        with open(errs_filename, "w") as f:
            wer = write_error_stats(
                f, f"{test_set_name}-{key}", results, enable_log=True
            )
            test_set_wers[key] = wer

        print("Wrote detailed error stats to {}".format(errs_filename))

    test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
    errs_info = res_dir / f"wer-summary-{test_set_name}-{key}-{suffix}.txt"
    with open(errs_info, "w") as f:
        print("settings\tWER", file=f)
        for key, val in test_set_wers:
            print("{}\t{}".format(key, val), file=f)

    s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
    note = "\tbest for {}".format(test_set_name)
    for key, val in test_set_wers:
        s += "{}\t{}{}\n".format(key, val, note)
        note = ""
    print(s)


def extract_hyp_ref_wavname(filename):
    """
    Parse lines such as:
    0Phqz8RWYuE_0007-5: ref=['R', 'Y', 'Y', 'B', '它最大的优势就是进光量或者说是对光线利用率的提升']
    0Phqz8RWYuE_0007-5: hyp=而YB它最大的优势是近光量或者说是对光线利用率的提升
    """
    hyps, refs, wav_name = [], [], []
    with open(filename, "r") as f:
        for line in f:
            if "ref=" in line:
                ref = line.split("ref=")[1].strip()
                # A ref may be stored as a python list of tokens;
                # join the tokens into a single string.
                if ref[0] == "[":
                    ref = ref[2:-2]
                    list_elements = ref.split("', '")
                    ref = "".join(list_elements)
                refs.append(ref)
            elif "hyp=" in line:
                hyp = line.split("hyp=")[1].strip()
                hyps.append(hyp)
                # Record the wav name once per utterance so that it stays
                # aligned with refs and hyps.
                wav_name.append(line.split(":")[0])
    return hyps, refs, wav_name


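# The SpeechIO test sets are numbered SPEECHIO_ASR_ZH00000 through
# SPEECHIO_ASR_ZH00026; the helper below builds one recogs path per test set,
# e.g. {whisper_log_dir}/recogs-SPEECHIO_ASR_ZH00014-beam-search-epoch-999-avg-1.txt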
def get_filenames(
    whisper_log_dir,
    whisper_suffix="beam-search-epoch-999-avg-1",
):
    results = []
    start_index, end_index = 0, 26
    dataset_parts = []
    for i in range(start_index, end_index + 1):
        idx = f"{i}".zfill(2)
        dataset_parts.append(f"SPEECHIO_ASR_ZH000{idx}")
    for partition in dataset_parts:
        whisper_filename = f"{whisper_log_dir}/recogs-{partition}-{whisper_suffix}.txt"
        results.append(whisper_filename)
    return results


def main():
    parser = get_parser()
    args = parser.parse_args()
    # Create the output directory if it does not exist.
    Path(args.output_log_dir).mkdir(parents=True, exist_ok=True)
    filenames = get_filenames(args.model_log_dir)
    for filename in filenames:
        hyps, refs, wav_name = extract_hyp_ref_wavname(filename)
        # File names look like recogs-{partition}-{suffix}.txt, so the
        # partition name is the second dash-separated field.
        partition_name = filename.split("/")[-1].split("-")[1]

        save_results_with_speechio_text_norm(
            Path(args.output_log_dir),
            partition_name,
            {"norm": list(zip(wav_name, refs, hyps))},
        )

        print(f"Processed {partition_name}")


if __name__ == "__main__":
    main()