#!/usr/bin/env python3
#
# Copyright 2024 Author: Yuekai Zhang
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This file uses speech io offcial pipline to normalize the decoding results.
https://github.com/SpeechColab/Leaderboard/blob/master/utils/textnorm_zh.py
Usage:
python normalize_results.py --model-log-dir ./whisper_decoding_log_dir --output-log-dir ./results_norm
"""
import argparse
from pathlib import Path
from typing import Dict, List, Tuple

from speechio_norm import TextNorm

from icefall.utils import store_transcripts, write_error_stats


def get_parser():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    parser.add_argument(
        "--model-log-dir",
        type=str,
        default="./recogs_whisper",
        help="The directory containing the whisper decoding logs, "
        "e.g. recogs-SPEECHIO_ASR_ZH00014-beam-search-epoch--1-avg-1.txt",
    )
    parser.add_argument(
        "--output-log-dir",
        type=str,
        default="./results_whisper_norm",
        help="The directory to store the normalized whisper logs",
    )
    return parser


def save_results_with_speechio_text_norm(
    res_dir: Path,
    test_set_name: str,
    results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
):
    normalizer = TextNorm()
    # Normalize the items in results_dict.
    for key, results in results_dict.items():
        results_norm = []
        for item in results:
            wav_name, ref, hyp = item
            ref = normalizer(ref)
            hyp = normalizer(hyp)
            results_norm.append((wav_name, ref, hyp))
        results_dict[key] = results_norm

    test_set_wers = dict()
    suffix = "epoch-999-avg-1"
    for key, results in results_dict.items():
        recog_path = res_dir / f"recogs-{test_set_name}-{key}-{suffix}.txt"
        results = sorted(results)
        store_transcripts(filename=recog_path, texts=results)
        print(f"The transcripts are stored in {recog_path}")

        # The following prints out WERs, per-word error statistics and aligned
        # ref/hyp pairs.
        errs_filename = res_dir / f"errs-{test_set_name}-{key}-{suffix}.txt"
        with open(errs_filename, "w") as f:
            wer = write_error_stats(
                f, f"{test_set_name}-{key}", results, enable_log=True
            )
            test_set_wers[key] = wer

        print("Wrote detailed error stats to {}".format(errs_filename))

    test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
    errs_info = res_dir / f"wer-summary-{test_set_name}-{key}-{suffix}.txt"
    with open(errs_info, "w") as f:
        print("settings\tWER", file=f)
        for key, val in test_set_wers:
            print("{}\t{}".format(key, val), file=f)

    s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
    note = "\tbest for {}".format(test_set_name)
    for key, val in test_set_wers:
        s += "{}\t{}{}\n".format(key, val, note)
        note = ""
    print(s)


def extract_hyp_ref_wavname(filename):
    """
    Extract hypotheses, references, and wav names from a recogs file whose
    lines look like:
    0Phqz8RWYuE_0007-5: ref=['R', 'Y', 'Y', 'B', '它最大的优势就是进光量或者说是对光线利用率的提升']
    0Phqz8RWYuE_0007-5: hyp=而YB它最大的优势是近光量或者说是对光线利用率的提升
    """
    hyps, refs, wav_name = [], [], []
    with open(filename, "r") as f:
        for line in f:
            if "ref" in line:
                ref = line.split("ref=")[1].strip()
                # The reference may be stored as a Python-list string;
                # join its elements back into a plain text string.
                if ref[0] == "[":
                    ref = ref[2:-2]
                    list_elements = ref.split("', '")
                    ref = "".join(list_elements)
                refs.append(ref)
            elif "hyp" in line:
                hyp = line.split("hyp=")[1].strip()
                hyps.append(hyp)
                wav_name.append(line.split(":")[0])
    return hyps, refs, wav_name
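

# A worked example for extract_hyp_ref_wavname above, assuming a recogs file
# that contains only the two sample lines shown in its docstring:
#   hyps     == ["而YB它最大的优势是近光量或者说是对光线利用率的提升"]
#   refs     == ["RYYB它最大的优势就是进光量或者说是对光线利用率的提升"]
#   wav_name == ["0Phqz8RWYuE_0007-5"]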


def get_filenames(
    whisper_log_dir,
    whisper_suffix="beam-search-epoch-999-avg-1",
):
    results = []
    start_index, end_index = 0, 26
    dataset_parts = []
    for i in range(start_index, end_index + 1):
        idx = f"{i}".zfill(2)
        dataset_parts.append(f"SPEECHIO_ASR_ZH000{idx}")
    for partition in dataset_parts:
        whisper_filename = f"{whisper_log_dir}/recogs-{partition}-{whisper_suffix}.txt"
        results.append(whisper_filename)
    return results


def main():
    parser = get_parser()
    args = parser.parse_args()

    # mkdir output_log_dir
    Path(args.output_log_dir).mkdir(parents=True, exist_ok=True)

    filenames = get_filenames(args.model_log_dir)
    for filename in filenames:
        hyps, refs, wav_name = extract_hyp_ref_wavname(filename)
        partition_name = filename.split("/")[-1].split("-")[1]

        save_results_with_speechio_text_norm(
            Path(args.output_log_dir),
            partition_name,
            {"norm": list(zip(wav_name, refs, hyps))},
        )

        print(f"Processed {partition_name}")


if __name__ == "__main__":
    main()