icefall/egs/must_c/ST/local/preprocess_must_c.py

97 lines
2.8 KiB
Python
Executable File

#!/usr/bin/env python3
# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang)
"""
This script normalizes transcripts from supervisions.
Usage:
./local/preprocess_must_c.py \
--manifest-dir ./data/manifests/v1.0/ \
--tgt-lang de
"""
import argparse
import logging
import re
from functools import partial
from pathlib import Path
from lhotse.recipes.utils import read_manifests_if_cached
from normalize_punctuation import normalize_punctuation
from remove_non_native_characters import remove_non_native_characters
from remove_punctuation import remove_punctuation
def get_args():
    """Parse this script's command-line options.

    Both options are mandatory (``required=True``); argparse reports an
    error and exits if either one is missing.

    Returns:
      An ``argparse.Namespace`` with ``manifest_dir`` (a ``Path``) and
      ``tgt_lang`` (a ``str``).
    """
    parser = argparse.ArgumentParser()
    # (flag, converter, help text) -- registered in this exact order.
    option_specs = (
        ("--manifest-dir", Path, "Manifest directory"),
        ("--tgt-lang", str, "Target language, e.g., zh, de, fr."),
    )
    for flag, converter, help_text in option_specs:
        parser.add_argument(
            flag,
            type=converter,
            required=True,
            help=help_text,
        )
    return parser.parse_args()
def _make_text_normalizer(tgt_lang: str):
    """Build the transcript-cleaning function applied to every supervision.

    The returned callable applies, in order: punctuation normalization for
    ``tgt_lang``, punctuation removal, lower-casing, removal of characters
    not native to ``tgt_lang``, and collapsing each run of spaces into one.

    Args:
      tgt_lang:
        Target language code, forwarded as ``lang=`` to the language-aware
        helpers.
    Returns:
      A function mapping one transcript string to its cleaned form.
    """
    # Compiled once here rather than resolved again for every transcript.
    multi_space = re.compile(" +")

    def normalize(text: str) -> str:
        text = normalize_punctuation(text, lang=tgt_lang)
        text = remove_punctuation(text)
        text = text.lower()
        text = remove_non_native_characters(text, lang=tgt_lang)
        # Collapse runs of spaces to a single space; leading/trailing
        # spaces are kept as-is.
        return multi_space.sub(" ", text)

    return normalize


def preprocess_must_c(manifest_dir: Path, tgt_lang: str):
    """Write text-normalized supervision manifests for every MuST-C split.

    For each split in ``dev``, ``tst-COMMON``, ``tst-HE`` and ``train``, the
    cached supervisions for ``en-{tgt_lang}_{split}`` are read from
    ``manifest_dir`` (via lhotse's ``read_manifests_if_cached`` with prefix
    ``must_c`` and suffix ``jsonl.gz``). Their text is cleaned with
    ``_make_text_normalizer`` and the result is written to
    ``{manifest_dir}/must_c_supervisions_en-{tgt_lang}_{split}_norm_rm.jsonl.gz``.
    A split whose output file already exists is skipped.

    Args:
      manifest_dir:
        Directory containing the input manifests; outputs are written here.
      tgt_lang:
        Target language, e.g., zh, de, fr.
    Raises:
      RuntimeError: if no supervisions are returned for a split.
    """
    prefix = "must_c"
    suffix = "jsonl.gz"
    parts = ["dev", "tst-COMMON", "tst-HE", "train"]
    normalize_text = _make_text_normalizer(tgt_lang)

    for p in parts:
        logging.info(f"Processing {p}")
        name = f"en-{tgt_lang}_{p}"
        # norm: normalization
        # rm: remove punctuation
        # Built from prefix/suffix so input and output naming stay in sync.
        dst_name = manifest_dir / f"{prefix}_supervisions_{name}_norm_rm.{suffix}"
        if dst_name.is_file():
            logging.info(f"{dst_name} exists - skipping")
            continue

        manifests = read_manifests_if_cached(
            dataset_parts=name,
            output_dir=manifest_dir,
            prefix=prefix,
            suffix=suffix,
            types=("supervisions",),
        )
        if name not in manifests:
            raise RuntimeError(f"Processing {p} failed.")

        supervisions = manifests[name]["supervisions"]
        # One transform_text call with the composed pipeline instead of five
        # chained calls, each of which produced an intermediate set.
        supervisions = supervisions.transform_text(normalize_text)
        supervisions.to_file(dst_name)
def main():
    """Script entry point: parse options, configure logging, run preprocessing.

    Raises:
      NotADirectoryError: if ``--manifest-dir`` is not an existing directory.
    """
    args = get_args()
    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
    logging.basicConfig(format=formatter, level=logging.INFO)
    logging.info(vars(args))

    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if not args.manifest_dir.is_dir():
        raise NotADirectoryError(f"{args.manifest_dir} is not a directory")

    preprocess_must_c(
        manifest_dir=args.manifest_dir,
        tgt_lang=args.tgt_lang,
    )
# Run the script only when executed directly, not when imported as a module.
if __name__ == "__main__":
    main()