icefall/egs/speechio/ASR/local/normalize_results.py
Yuekai Zhang 6d7c1d13a5
update speechio whisper ft results (#1605)
* update speechio whisper ft results
2024-04-30 11:49:20 +08:00

166 lines
5.6 KiB
Python
Executable File

#!/usr/bin/env python3
#
# Copyright 2024 Author: Yuekai Zhang
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This file post-processes existing decoding results (e.g. whisper recogs-*.txt files):
it applies the SpeechIO text normalization (TextNorm) to every reference and hypothesis,
then stores the normalized transcripts and recomputes the error statistics.
Note: this script does not perform whisper/zipformer fusion decoding; it only
normalizes and re-scores results that were already produced.
Usage:
python normalize_results.py --model-log-dir ./whisper_decoding_log_dir --output-log-dir ./results_norm
"""
import argparse
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import kaldialign
from speechio_norm import TextNorm
from icefall.utils import store_transcripts, write_error_stats
def get_parser():
    """Create the command-line parser for this script.

    Two string options are registered:

    * ``--model-log-dir``: directory holding the input ``recogs-*.txt`` logs.
    * ``--output-log-dir``: directory receiving the normalized outputs.

    Returns:
      The configured ``argparse.ArgumentParser``; defaults are shown in
      ``--help`` through ``ArgumentDefaultsHelpFormatter``.
    """
    cli = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    # (flag, default value, help text) for every supported option.
    option_table = (
        (
            "--model-log-dir",
            "./recogs_whisper",
            "The directory to store the whisper logs: e.g. recogs-SPEECHIO_ASR_ZH00014-beam-search-epoch--1-avg-1.txt",
        ),
        (
            "--output-log-dir",
            "./results_whisper_norm",
            "The directory to store the normalized whisper logs",
        ),
    )
    for flag, default_value, help_text in option_table:
        cli.add_argument(flag, type=str, default=default_value, help=help_text)

    return cli
def save_results_with_speechio_text_norm(
    res_dir: Path,
    test_set_name: str,
    results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
):
    """Normalize refs/hyps with ``TextNorm`` and write transcripts and error stats.

    For every ``key`` in ``results_dict`` the following files are written
    into ``res_dir`` (suffix is fixed to ``epoch-999-avg-1``):

    * ``recogs-{test_set_name}-{key}-{suffix}.txt`` via ``store_transcripts``
    * ``errs-{test_set_name}-{key}-{suffix}.txt`` via ``write_error_stats``

    Finally a single ``wer-summary-{test_set_name}-{key}-{suffix}.txt`` is
    written, where ``key`` is the last key of ``results_dict``, and the same
    summary is printed to stdout.

    Args:
      res_dir:
        Directory where all output files are created (must already exist).
      test_set_name:
        Name of the test set; used in file names and log messages.
      results_dict:
        Maps a setting name to a list of ``(wav_name, ref, hyp)`` items.
        ``ref`` and ``hyp`` are each passed through ``TextNorm()``.
        The dict passed in is left untouched.

    Returns:
      None. When ``results_dict`` is empty nothing is written.
    """
    if not results_dict:
        # Without any setting there is no key to build the summary file name
        # from, so bail out early instead of failing on an unbound variable.
        print(f"No results to normalize for {test_set_name}, nothing written")
        return

    normalizer = TextNorm()

    # Normalize items into a new dict so that the caller's data is not mutated.
    normalized_dict = {
        key: [
            (wav_name, normalizer(ref), normalizer(hyp))
            for wav_name, ref, hyp in results
        ]
        for key, results in results_dict.items()
    }

    test_set_wers = dict()
    suffix = "epoch-999-avg-1"
    last_key = None
    for key, results in normalized_dict.items():
        last_key = key
        recog_path = res_dir / f"recogs-{test_set_name}-{key}-{suffix}.txt"
        results = sorted(results)
        store_transcripts(filename=recog_path, texts=results)
        print(f"The transcripts are stored in {recog_path}")

        # The following prints out WERs, per-word error statistics and aligned
        # ref/hyp pairs.
        errs_filename = res_dir / f"errs-{test_set_name}-{key}-{suffix}.txt"
        # Explicit UTF-8: the transcripts contain Chinese text and must not
        # depend on the locale's default encoding.
        with open(errs_filename, "w", encoding="utf-8") as f:
            wer = write_error_stats(
                f, f"{test_set_name}-{key}", results, enable_log=True
            )
            test_set_wers[key] = wer

        print("Wrote detailed error stats to {}".format(errs_filename))

    # Settings ordered by the value returned from write_error_stats, smallest first.
    ranked_wers = sorted(test_set_wers.items(), key=lambda x: x[1])

    # The summary file is named after the last processed key.
    errs_info = res_dir / f"wer-summary-{test_set_name}-{last_key}-{suffix}.txt"
    with open(errs_info, "w", encoding="utf-8") as f:
        print("settings\tWER", file=f)
        for key, val in ranked_wers:
            print("{}\t{}".format(key, val), file=f)

    s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
    note = "\tbest for {}".format(test_set_name)
    for key, val in ranked_wers:
        s += "{}\t{}{}\n".format(key, val, note)
        # Only the first (best) entry carries the note.
        note = ""
    print(s)
def extract_hyp_ref_wavname(filename):
    """
    Parse a recogs file into hypotheses, references and wav names.

    Expected content, one ref line and one hyp line per utterance:

    0Phqz8RWYuE_0007-5: ref=['R', 'Y', 'Y', 'B', '它最大的优势就是进光量或者说是对光线利用率的提升']
    0Phqz8RWYuE_0007-5: hyp=而YB它最大的优势是近光量或者说是对光线利用率的提升

    A line is classified by whichever of the markers ``ref=`` / ``hyp=``
    occurs first in it, so wav names or transcripts that merely contain the
    letters "ref" or "hyp" are handled correctly. Lines with neither marker
    are ignored.

    Args:
      filename:
        Path of the recogs file; it is read as UTF-8.

    Returns:
      A tuple ``(hyps, refs, wav_name)`` of lists. ``refs`` holds the ref
      tokens joined without separator, ``hyps`` the text after ``hyp=``, and
      ``wav_name`` the text before the first ``:`` of each hyp line.
    """
    import ast

    ref_marker, hyp_marker = "ref=", "hyp="
    hyps, refs, wav_name = [], [], []
    # Explicit UTF-8: the transcripts contain Chinese text and must not depend
    # on the locale's default encoding.
    with open(filename, "r", encoding="utf-8") as f:
        for line in f:
            ref_pos = line.find(ref_marker)
            hyp_pos = line.find(hyp_marker)
            is_ref_line = ref_pos != -1 and (hyp_pos == -1 or ref_pos < hyp_pos)

            if is_ref_line:
                raw = line[ref_pos + len(ref_marker) :].strip()
                # The ref is the repr of a list of tokens; parse it as a
                # literal so that tokens quoted with double quotes (i.e.
                # containing an apostrophe) are handled as well.
                try:
                    tokens = ast.literal_eval(raw)
                except (ValueError, SyntaxError, TypeError):
                    tokens = None
                if isinstance(tokens, (list, tuple)):
                    ref = "".join(str(token) for token in tokens)
                else:
                    # Fallback: strip the leading "['" and trailing "']" and
                    # split on the "', '" separators.
                    ref = "".join(raw[2:-2].split("', '"))
                refs.append(ref)
            elif hyp_pos != -1:
                hyps.append(line[hyp_pos + len(hyp_marker) :].strip())
                # The wav name is recorded once per utterance, on the hyp line.
                wav_name.append(line.split(":")[0])
    return hyps, refs, wav_name
def get_filenames(
    whisper_log_dir,
    whisper_suffix="beam-search-epoch-999-avg-1",
    start_index=0,
    end_index=26,
):
    """Build the list of recogs file paths for a range of SpeechIO partitions.

    Partitions are named ``SPEECHIO_ASR_ZH000{idx}`` where ``idx`` is the
    partition index zero-padded to two digits.

    Args:
      whisper_log_dir:
        Directory containing the recogs files; joined with ``/``.
      whisper_suffix:
        Suffix that follows the partition name in each file name.
      start_index:
        Index of the first partition (inclusive). Defaults to 0.
      end_index:
        Index of the last partition (inclusive). Defaults to 26.

    Returns:
      A list of path strings of the form
      ``{whisper_log_dir}/recogs-{partition}-{whisper_suffix}.txt``, ordered
      by partition index. Empty when ``start_index > end_index``. The files
      are not checked for existence.
    """
    return [
        f"{whisper_log_dir}/recogs-SPEECHIO_ASR_ZH000{i:02d}-{whisper_suffix}.txt"
        for i in range(start_index, end_index + 1)
    ]
def main():
    """Normalize and re-score every partition's recogs file.

    Parses the command-line options, creates the output directory, and then,
    for each file name produced by ``get_filenames``, extracts the
    (wav name, ref, hyp) triples and hands them to
    ``save_results_with_speechio_text_norm`` under the setting name "norm".
    """
    args = get_parser().parse_args()

    # Make sure the output directory exists before anything is written.
    output_dir = Path(args.output_log_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    for recog_file in get_filenames(args.model_log_dir):
        hyps, refs, wav_names = extract_hyp_ref_wavname(recog_file)

        # File names look like "recogs-<partition>-<suffix>.txt"; the
        # partition is the second "-"-separated field of the base name.
        base_name = recog_file.split("/")[-1]
        partition_name = base_name.split("-")[1]

        utterances = list(zip(wav_names, refs, hyps))
        save_results_with_speechio_text_norm(
            output_dir,
            partition_name,
            {"norm": utterances},
        )
        print(f"Processed {partition_name}")
# Run the normalization only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()