mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 10:02:22 +00:00
97 lines
2.8 KiB
Python
Executable File
97 lines
2.8 KiB
Python
Executable File
#!/usr/bin/env python3
|
|
# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang)
|
|
"""
|
|
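This script normalizes the transcripts in the MuST-C supervision manifests
(punctuation normalization, punctuation removal, lower-casing, removal of
non-native characters and collapsing of repeated spaces) and saves the
results as new "*_norm_rm" supervision manifests.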
|
|
|
|
Usage:
|
|
./local/preprocess_must_c.py \
|
|
--manifest-dir ./data/manifests/v1.0/ \
|
|
--tgt-lang de
|
|
"""
|
|
|
|
import argparse
|
|
import logging
|
|
import re
|
|
from functools import partial
|
|
from pathlib import Path
|
|
|
|
from lhotse.recipes.utils import read_manifests_if_cached
|
|
from normalize_punctuation import normalize_punctuation
|
|
from remove_non_native_characters import remove_non_native_characters
|
|
from remove_punctuation import remove_punctuation
|
|
|
|
|
|
def get_args():
    """Parse the command-line arguments of this script.

    Both options are mandatory.

    Returns:
      An ``argparse.Namespace`` with:
        - ``manifest_dir`` (Path): directory holding the MuST-C manifests.
        - ``tgt_lang`` (str): target language code, e.g., zh, de, fr.
    """
    cli = argparse.ArgumentParser()

    # (flag, converter, help text) for every required option, in the order
    # they should appear in --help.
    option_specs = (
        ("--manifest-dir", Path, "Manifest directory"),
        ("--tgt-lang", str, "Target language, e.g., zh, de, fr."),
    )
    for flag, converter, description in option_specs:
        cli.add_argument(flag, type=converter, required=True, help=description)

    parsed = cli.parse_args()
    return parsed
|
|
|
|
|
|
def preprocess_must_c(manifest_dir: Path, tgt_lang: str):
    """Normalize the transcripts of the MuST-C supervision manifests.

    For each split ("dev", "tst-COMMON", "tst-HE", "train") of the
    ``en-{tgt_lang}`` pair, the supervisions are read from ``manifest_dir``
    via ``read_manifests_if_cached`` (prefix "must_c", suffix "jsonl.gz").
    The following steps are applied, in order, to every supervision text:

      1. punctuation normalization for ``tgt_lang``,
      2. punctuation removal,
      3. lower-casing,
      4. removal of characters that are non-native to ``tgt_lang``,
      5. collapsing runs of spaces into a single space.

    The result is written to
    ``must_c_supervisions_en-{tgt_lang}_{split}_norm_rm.jsonl.gz`` in
    ``manifest_dir``. Splits whose output file already exists are skipped.

    Args:
      manifest_dir:
        Directory containing the MuST-C manifests. Outputs are written here.
      tgt_lang:
        Target language code, e.g., zh, de, fr.

    Raises:
      RuntimeError: if ``read_manifests_if_cached`` returns no supervisions
        for a split.
    """
    normalize_punctuation_lang = partial(normalize_punctuation, lang=tgt_lang)
    remove_non_native_characters_lang = partial(
        remove_non_native_characters, lang=tgt_lang
    )
    # Compiled once instead of being looked up for every supervision.
    multiple_spaces = re.compile(" +")

    def normalize_text(text: str) -> str:
        # Order matters: punctuation is normalized before it is removed, and
        # the spaces left behind by the removal steps are collapsed last.
        text = normalize_punctuation_lang(text)
        text = remove_punctuation(text)
        text = text.lower()
        text = remove_non_native_characters_lang(text)
        return multiple_spaces.sub(" ", text)

    prefix = "must_c"
    suffix = "jsonl.gz"
    parts = ["dev", "tst-COMMON", "tst-HE", "train"]
    for p in parts:
        logging.info(f"Processing {p}")
        name = f"en-{tgt_lang}_{p}"

        # norm: normalization
        # rm: remove punctuation
        dst_name = manifest_dir / f"{prefix}_supervisions_{name}_norm_rm.{suffix}"
        if dst_name.is_file():
            logging.info(f"{dst_name} exists - skipping")
            continue

        manifests = read_manifests_if_cached(
            dataset_parts=name,
            output_dir=manifest_dir,
            prefix=prefix,
            suffix=suffix,
            types=("supervisions",),
        )
        if name not in manifests:
            raise RuntimeError(f"Processing {p} failed.")

        supervisions = manifests[name]["supervisions"]
        # A single pass applying all normalization steps to each text.
        supervisions = supervisions.transform_text(normalize_text)

        # Write to a temporary file and rename it only after writing has
        # finished. Otherwise an interrupted run would leave a truncated output
        # file that the existence check above would skip on the next run.
        # The temporary name keeps the ".jsonl.gz" ending because lhotse's
        # `to_file` infers the format/compression from the extension.
        tmp_name = dst_name.with_name(f"tmp_{dst_name.name}")
        supervisions.to_file(tmp_name)
        tmp_name.replace(dst_name)
|
|
|
|
|
|
def main():
    """Entry point: parse arguments, configure logging, run preprocessing.

    Raises:
      NotADirectoryError: if ``--manifest-dir`` does not name an existing
        directory.
    """
    args = get_args()

    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
    logging.basicConfig(format=formatter, level=logging.INFO)

    logging.info(vars(args))

    # Validate explicitly rather than with `assert`, which is stripped when
    # Python runs with -O.
    if not args.manifest_dir.is_dir():
        raise NotADirectoryError(
            f"--manifest-dir {args.manifest_dir} does not exist "
            "or is not a directory"
        )

    preprocess_must_c(
        manifest_dir=args.manifest_dir,
        tgt_lang=args.tgt_lang,
    )


if __name__ == "__main__":
    main()
|