icefall/egs/librispeech/WSASR/local/make_error_cutset.py
2023-09-29 07:52:46 +08:00

176 lines
4.4 KiB
Python
Executable File

#!/usr/bin/env python3
# Copyright 2023 Johns Hopkins University (author: Dongji Gao)
import argparse
import random
from pathlib import Path
from typing import List
from lhotse import CutSet, load_manifest
from lhotse.cut.base import Cut
from icefall.utils import str2bool
def get_args():
    """Build and parse this script's command-line arguments."""
    ap = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    ap.add_argument(
        "--input-cutset",
        type=str,
        help="Supervision manifest that contains verbatim transcript",
    )
    ap.add_argument(
        "--words-file",
        type=str,
        help="words.txt file",
    )
    ap.add_argument(
        "--otc-token",
        type=str,
        help="OTC token in words.txt",
    )
    # The three simulated-error probabilities share type and default.
    for flag, text in (
        ("--sub-error-rate", "Substitution error rate"),
        ("--ins-error-rate", "Insertion error rate"),
        ("--del-error-rate", "Deletion error rate"),
    ):
        ap.add_argument(flag, type=float, default=0.0, help=text)
    ap.add_argument(
        "--output-cutset",
        type=str,
        default="",
        help="Supervision manifest that contains modified non-verbatim transcript",
    )
    ap.add_argument("--verbose", type=str2bool, help="show details of errors")
    return ap.parse_args()
def check_args(args):
    """Validate the simulated-error probabilities.

    Each of ``sub_error_rate``, ``ins_error_rate``, and ``del_error_rate``
    must lie in [0, 1], and their sum must not exceed 1 (one random draw per
    token selects at most one error type).

    Raises:
        AssertionError: if any rate is out of range or the total exceeds 1.
    """
    total_error_rate = args.sub_error_rate + args.ins_error_rate + args.del_error_rate
    # Fix: the original compared sub_error_rate against 1.0 in all three
    # asserts, leaving the upper bounds of ins/del unchecked.
    assert 0 <= args.sub_error_rate <= 1.0
    assert 0 <= args.ins_error_rate <= 1.0
    assert 0 <= args.del_error_rate <= 1.0
    assert total_error_rate <= 1.0
def get_word_list(token_path: str) -> List:
    """Read the first field of every line of a lexicon-style file.

    Args:
        token_path: path to a ``words.txt``-style file whose lines start
            with a token (e.g. Kaldi's "<word> <id>" format).

    Returns:
        The tokens in file order.

    Raises:
        AssertionError: if the same token appears on more than one line.
    """
    word_list = []
    seen = set()  # O(1) duplicate detection instead of O(n) list scans
    with open(Path(token_path), "r") as tp:
        for line in tp:
            fields = line.split()
            if not fields:
                # Tolerate blank lines instead of crashing on split()[0].
                continue
            token = fields[0]
            assert token not in seen
            seen.add(token)
            word_list.append(token)
    return word_list
def modify_cut_text(
    cut: Cut,
    words_list: List,
    non_words: List,
    sub_ratio: float = 0.0,
    ins_ratio: float = 0.0,
    del_ratio: float = 0.0,
):
    """Randomly corrupt the transcript of ``cut`` to simulate a
    non-verbatim transcript.

    The original (verbatim) text is kept on the supervision as
    ``verbatim_text`` with markers showing what happened to each token:
    ``-word-`` deleted, ``[word]`` substituted, ``[]`` insertion point.

    Args:
        cut: a cut whose first supervision's text is modified in place
            (only ``supervisions[0]`` is used).
        words_list: candidate replacement words (e.g. from words.txt).
        non_words: tokens that must never be drawn as replacements.
        sub_ratio: per-token substitution probability.
        ins_ratio: per-token insertion probability.
        del_ratio: per-token deletion probability.

    Returns:
        The same cut, with ``supervisions[0].text`` replaced by the
        modified transcript.
    """
    text = cut.supervisions[0].text
    text_list = text.split()

    # We save the modified information of the original verbatim text for debugging
    marked_verbatim_text_list = []
    modified_text_list = []

    # We follow the order: deletion -> substitution -> insertion.
    # A single uniform draw selects at most one error type per token.
    for token in text_list:
        marked_token = token
        modified_token = token

        prob = random.random()
        if prob <= del_ratio:
            # Deletion: mark the token and drop it from the modified text.
            marked_token = f"-{token}-"
            modified_token = ""
        elif prob <= del_ratio + sub_ratio + ins_ratio:
            if prob <= del_ratio + sub_ratio:
                # Substitution: the token itself is replaced below.
                marked_token = f"[{token}]"
            else:
                # Insertion: keep the original token, then insert a random
                # word after it (marked as "[]" in the verbatim text).
                marked_verbatim_text_list.append(marked_token)
                modified_text_list.append(modified_token)
                marked_token = "[]"
            # Draw a replacement that differs from the original token, is not
            # a non-word token, and is not a disambiguation symbol ("#...").
            # NOTE(review): loops forever if words_list has no valid choice.
            while (
                modified_token == token
                or modified_token in non_words
                or modified_token.startswith("#")
            ):
                modified_token = random.choice(words_list)

        marked_verbatim_text_list.append(marked_token)
        modified_text_list.append(modified_token)

    marked_text = " ".join(marked_verbatim_text_list)
    # Fix: deletions leave "" entries in modified_text_list; the original
    # join produced double spaces in the output transcript. Skip empties.
    modified_text = " ".join(t for t in modified_text_list if t)

    if not hasattr(cut.supervisions[0], "verbatim_text"):
        cut.supervisions[0].verbatim_text = marked_text
    cut.supervisions[0].text = modified_text

    return cut
def main():
    """Read a cutset, inject simulated transcript errors, write it out."""
    args = get_args()
    check_args(args)

    # Tokens that must never be drawn as replacement words.
    non_words = {"sil", "<UNK>", "<eps>", args.otc_token}
    words_list = get_word_list(args.words_file)

    cutset = load_manifest(Path(args.input_cutset))
    modified_cuts = [
        modify_cut_text(
            cut=cut,
            words_list=words_list,
            non_words=non_words,
            sub_ratio=args.sub_error_rate,
            ins_ratio=args.ins_error_rate,
            del_ratio=args.del_error_rate,
        )
        for cut in cutset
    ]

    CutSet.from_cuts(modified_cuts).to_file(args.output_cutset)


if __name__ == "__main__":
    main()