mirror of https://github.com/k2-fsa/icefall.git (synced 2025-08-08 09:32:20 +00:00)

Add data preparation for the MuST-C speech translation corpus (#1107)

parent ba257efbcd
commit c0de78d3c0
@@ -107,7 +107,7 @@ fi
 if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
   log "Stage 2: Prepare musan manifest"
   # We assume that you have downloaded the musan corpus
-  # to data/musan
+  # to $dl_dir/musan
   mkdir -p data/manifests
   if [ ! -e data/manifests/.musan.done ]; then
     lhotse prepare musan $dl_dir/musan data/manifests
1  egs/must_c/ST/local/compute_fbank_musan.py  Symbolic link
@@ -0,0 +1 @@
../../../librispeech/ASR/local/compute_fbank_musan.py
155  egs/must_c/ST/local/compute_fbank_must_c.py  Executable file
@@ -0,0 +1,155 @@
#!/usr/bin/env python3
# Copyright    2023  Xiaomi Corp.        (authors: Fangjun Kuang)

"""
This file computes fbank features of the MuST-C dataset.
It looks for manifests in the directory "in_dir" and writes
generated features to "out_dir".
"""

import argparse
import logging
from pathlib import Path

import torch
from lhotse import (
    CutSet,
    Fbank,
    FbankConfig,
    LilcomChunkyWriter,
    load_manifest,
)

from icefall.utils import str2bool

# Torch's multithreaded behavior needs to be disabled or
# it wastes a lot of CPU and slows things down.
# Do this outside of main() in case it needs to take effect
# even when we are not invoking the main (e.g. when spawning subprocesses).
torch.set_num_threads(1)
torch.set_num_interop_threads(1)


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--in-dir",
        type=Path,
        required=True,
        help="Input manifest directory",
    )

    parser.add_argument(
        "--out-dir",
        type=Path,
        required=True,
        help="Output directory where generated fbank features are saved.",
    )

    parser.add_argument(
        "--tgt-lang",
        type=str,
        required=True,
        help="Target language, e.g., zh, de, fr.",
    )

    parser.add_argument(
        "--num-jobs",
        type=int,
        default=1,
        help="Number of jobs for computing features",
    )

    parser.add_argument(
        "--perturb-speed",
        type=str2bool,
        default=False,
        help="""True to enable speed perturb with factors 0.9 and 1.1 on
        the train subset. False (by default) to disable speed perturb.
        """,
    )

    return parser.parse_args()


def compute_fbank_must_c(
    in_dir: Path,
    out_dir: Path,
    tgt_lang: str,
    num_jobs: int,
    perturb_speed: bool,
):
    out_dir.mkdir(parents=True, exist_ok=True)

    extractor = Fbank(FbankConfig(num_mel_bins=80))

    parts = ["dev", "tst-COMMON", "tst-HE", "train"]

    prefix = "must_c"
    for p in parts:
        logging.info(f"Processing {p}")

        cuts_path = f"{out_dir}/{prefix}_feats_en-{tgt_lang}_{p}"
        if perturb_speed and p == "train":
            cuts_path += "_sp"

        cuts_path += ".jsonl.gz"

        if Path(cuts_path).is_file():
            logging.info(f"{cuts_path} exists - skipping")
            continue

        recordings_filename = in_dir / f"{prefix}_recordings_en-{tgt_lang}_{p}.jsonl.gz"
        supervisions_filename = (
            in_dir / f"{prefix}_supervisions_en-{tgt_lang}_{p}_norm_rm.jsonl.gz"
        )
        assert recordings_filename.is_file(), recordings_filename
        assert supervisions_filename.is_file(), supervisions_filename

        cut_set = CutSet.from_manifests(
            recordings=load_manifest(recordings_filename),
            supervisions=load_manifest(supervisions_filename),
        )
        if perturb_speed and p == "train":
            logging.info("Speed perturbing for the train dataset")
            cut_set = cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
            storage_path = f"{out_dir}/{prefix}_feats_en-{tgt_lang}_{p}_sp"
        else:
            storage_path = f"{out_dir}/{prefix}_feats_en-{tgt_lang}_{p}"

        cut_set = cut_set.compute_and_store_features(
            extractor=extractor,
            storage_path=storage_path,
            num_jobs=num_jobs,
            storage_type=LilcomChunkyWriter,
        )

        logging.info("About to split cuts into smaller chunks.")
        cut_set = cut_set.trim_to_supervisions(
            keep_overlapping=False, min_duration=None
        )

        logging.info(f"Saving to {cuts_path}")
        cut_set.to_file(cuts_path)
        logging.info(f"Saved to {cuts_path}")


def main():
    args = get_args()
    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
    logging.basicConfig(format=formatter, level=logging.INFO)

    logging.info(vars(args))
    assert args.in_dir.is_dir(), args.in_dir

    compute_fbank_must_c(
        in_dir=args.in_dir,
        out_dir=args.out_dir,
        tgt_lang=args.tgt_lang,
        num_jobs=args.num_jobs,
        perturb_speed=args.perturb_speed,
    )


if __name__ == "__main__":
    main()
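The cut manifests written by this script can be inspected with lhotse. A minimal sketch (not part of the commit; the path assumes version v1.0 and tgt_lang de, the defaults in prepare.sh below):

    from lhotse import load_manifest_lazy

    # The file name follows the pattern built above:
    # {out_dir}/{prefix}_feats_en-{tgt_lang}_{part}.jsonl.gz
    cuts = load_manifest_lazy("data/fbank/v1.0/must_c_feats_en-de_dev.jsonl.gz")
    for cut in cuts:
        print(cut.id, cut.duration, cut.supervisions[0].text)
        break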
34  egs/must_c/ST/local/get_text.py  Executable file
@@ -0,0 +1,34 @@
#!/usr/bin/env python3
# Copyright    2023  Xiaomi Corp.        (authors: Fangjun Kuang)
"""
This file prints the text field of supervisions from a cutset to the console.
"""

import argparse
from pathlib import Path

from lhotse import load_manifest_lazy


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "manifest",
        type=Path,
        help="Input manifest",
    )
    return parser.parse_args()


def main():
    args = get_args()
    assert args.manifest.is_file(), args.manifest

    cutset = load_manifest_lazy(args.manifest)
    for c in cutset:
        for sup in c.supervisions:
            print(sup.text)


if __name__ == "__main__":
    main()
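Usage note: Stage 6 of prepare.sh below invokes this script as ./local/get_text.py ./data/fbank/$version/must_c_feats_en-${tgt_lang}_train.jsonl.gz > $lang_dir/transcript_words.txt to dump the normalized training transcripts.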
48  egs/must_c/ST/local/get_words.py  Executable file
@@ -0,0 +1,48 @@
#!/usr/bin/env python3
# Copyright    2023  Xiaomi Corp.        (authors: Fangjun Kuang)
"""
This file generates words.txt from the given transcript file.
"""

import argparse
from pathlib import Path


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "transcript",
        type=Path,
        help="Input transcript file",
    )
    return parser.parse_args()


def main():
    args = get_args()
    assert args.transcript.is_file(), args.transcript

    word_set = set()
    with open(args.transcript) as f:
        for line in f:
            words = line.strip().split()
            for w in words:
                word_set.add(w)

    # Note: reserved* should be kept in sync with ./local/prepare_lang_bpe.py
    reserved1 = ["<eps>", "!SIL", "<SPOKEN_NOISE>", "<UNK>"]
    reserved2 = ["#0", "<s>", "</s>"]

    for w in reserved1 + reserved2:
        assert w not in word_set, w

    words = sorted(list(word_set))
    words = reserved1 + words + reserved2

    for i, w in enumerate(words):
        print(w, i)


if __name__ == "__main__":
    main()
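For illustration (a hypothetical example, not part of the commit): given a transcript containing only "hallo welt", the printed symbol table would be

    <eps> 0
    !SIL 1
    <SPOKEN_NOISE> 2
    <UNK> 3
    hallo 4
    welt 5
    #0 6
    <s> 7
    </s> 8

with the reserved symbols pinned to the lowest and highest ids, as the synced ./local/prepare_lang_bpe.py expects.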
169  egs/must_c/ST/local/normalize_punctuation.py  Normal file
@@ -0,0 +1,169 @@
# Copyright    2023  Xiaomi Corp.        (authors: Fangjun Kuang)
import re


def normalize_punctuation(s: str, lang: str) -> str:
    """
    This function implements
    https://github.com/moses-smt/mosesdecoder/blob/master/scripts/tokenizer/normalize-punctuation.perl

    Args:
      s:
        A string to be normalized.
      lang:
        The language to which `s` belongs.
    Returns:
      Return a normalized string.
    """
    # s/\r//g;
    s = re.sub("\r", "", s)

    # remove extra spaces
    # s/\(/ \(/g;
    s = re.sub(r"\(", " (", s)  # add a space before (

    # s/\)/\) /g; s/ +/ /g;
    s = re.sub(r"\)", ") ", s)  # add a space after )
    s = re.sub(" +", " ", s)  # convert multiple spaces to one

    # s/\) ([\.\!\:\?\;\,])/\)$1/g;
    s = re.sub(r"\) ([\.\!\:\?\;\,])", r")\1", s)

    # s/\( /\(/g;
    s = re.sub(r"\( ", "(", s)  # remove space after (

    # s/ \)/\)/g;
    s = re.sub(r" \)", ")", s)  # remove space before )

    # s/(\d) \%/$1\%/g;
    s = re.sub(r"(\d) \%", r"\1%", s)  # remove space between a digit and %

    # s/ :/:/g;
    s = re.sub(" :", ":", s)  # remove space before :

    # s/ ;/;/g;
    s = re.sub(" ;", ";", s)  # remove space before ;

    # normalize unicode punctuation
    # s/\`/\'/g;
    s = re.sub("`", "'", s)  # replace ` with '

    # s/\'\'/ \" /g;
    s = re.sub("''", '"', s)  # replace '' with "

    # s/„/\"/g;
    s = re.sub("„", '"', s)  # replace „ with "

    # s/“/\"/g;
    s = re.sub("“", '"', s)  # replace “ with "

    # s/”/\"/g;
    s = re.sub("”", '"', s)  # replace ” with "

    # s/–/-/g;
    s = re.sub("–", "-", s)  # replace – with -

    # s/—/ - /g; s/ +/ /g;
    s = re.sub("—", " - ", s)
    s = re.sub(" +", " ", s)  # convert multiple spaces to one

    # s/´/\'/g;
    s = re.sub("´", "'", s)

    # s/([a-z])‘([a-z])/$1\'$2/gi;
    s = re.sub("([a-z])‘([a-z])", r"\1'\2", s, flags=re.IGNORECASE)

    # s/([a-z])’([a-z])/$1\'$2/gi;
    s = re.sub("([a-z])’([a-z])", r"\1'\2", s, flags=re.IGNORECASE)

    # s/‘/\'/g;
    s = re.sub("‘", "'", s)

    # s/‚/\'/g;
    s = re.sub("‚", "'", s)

    # s/’/\"/g;
    s = re.sub("’", '"', s)

    # s/''/\"/g;
    s = re.sub("''", '"', s)

    # s/´´/\"/g;
    s = re.sub("´´", '"', s)

    # s/…/.../g;
    s = re.sub("…", "...", s)

    # French quotes

    # s/ « / \"/g;
    s = re.sub(" « ", ' "', s)

    # s/« /\"/g;
    s = re.sub("« ", '"', s)

    # s/«/\"/g;
    s = re.sub("«", '"', s)

    # s/ » /\" /g;
    s = re.sub(" » ", '" ', s)

    # s/ »/\"/g;
    s = re.sub(" »", '"', s)

    # s/»/\"/g;
    s = re.sub("»", '"', s)

    # handle pseudo-spaces
    # Note: in the original Moses script, the spaces inside the patterns
    # below are non-breaking spaces (U+00A0); plain spaces are used here.

    # s/ \%/\%/g;
    s = re.sub(" %", "%", s)

    # s/nº /nº /g;
    s = re.sub("nº ", "nº ", s)

    # s/ :/:/g;
    s = re.sub(" :", ":", s)

    # s/ ºC/ ºC/g;
    s = re.sub(" ºC", " ºC", s)

    # s/ cm/ cm/g;
    s = re.sub(" cm", " cm", s)

    # s/ \?/\?/g;
    s = re.sub(r" \?", "?", s)

    # s/ \!/\!/g;
    s = re.sub(r" \!", "!", s)

    # s/ ;/;/g;
    s = re.sub(" ;", ";", s)

    # s/, /, /g; s/ +/ /g;
    s = re.sub(", ", ", ", s)
    s = re.sub(" +", " ", s)

    if lang == "en":
        # English "quotation," followed by comma, style
        # s/\"([,\.]+)/$1\"/g;
        s = re.sub(r'"([,\.]+)', r'\1"', s)
    elif lang in ("cs", "cz"):
        # Czech is confused
        pass
    else:
        # German/Spanish/French "quotation", followed by comma, style
        # s/,\"/\",/g;
        s = re.sub(',"', '",', s)

        # s/(\.+)\"(\s*[^<])/\"$1$2/g; # don't fix period at end of sentence
        s = re.sub(r'(\.+)"(\s*[^<])', r'"\1\2', s)

    if lang in ("de", "es", "cz", "cs", "fr"):
        # s/(\d) (\d)/$1,$2/g;
        s = re.sub(r"(\d) (\d)", r"\1,\2", s)
    else:
        # s/(\d) (\d)/$1.$2/g;
        s = re.sub(r"(\d) (\d)", r"\1.\2", s)

    return s
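A quick illustration (a sketch, not part of the commit; the expected output follows the rules above for lang="de"):

    from normalize_punctuation import normalize_punctuation

    print(normalize_punctuation('„Hallo“ – 1 %', lang="de"))
    # -> "Hallo" - 1%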
1  egs/must_c/ST/local/prepare_lang.py  Symbolic link
@@ -0,0 +1 @@
../../../librispeech/ASR/local/prepare_lang.py
1  egs/must_c/ST/local/prepare_lang_bpe.py  Symbolic link
@@ -0,0 +1 @@
../../../librispeech/ASR/local/prepare_lang_bpe.py
96  egs/must_c/ST/local/preprocess_must_c.py  Executable file
@@ -0,0 +1,96 @@
#!/usr/bin/env python3
# Copyright    2023  Xiaomi Corp.        (authors: Fangjun Kuang)
"""
This script normalizes transcripts from supervisions.

Usage:
  ./local/preprocess_must_c.py \
    --manifest-dir ./data/manifests/v1.0/ \
    --tgt-lang de
"""

import argparse
import logging
import re
from functools import partial
from pathlib import Path

from lhotse.recipes.utils import read_manifests_if_cached
from normalize_punctuation import normalize_punctuation
from remove_non_native_characters import remove_non_native_characters
from remove_punctuation import remove_punctuation


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--manifest-dir",
        type=Path,
        required=True,
        help="Manifest directory",
    )
    parser.add_argument(
        "--tgt-lang",
        type=str,
        required=True,
        help="Target language, e.g., zh, de, fr.",
    )
    return parser.parse_args()


def preprocess_must_c(manifest_dir: Path, tgt_lang: str):
    normalize_punctuation_lang = partial(normalize_punctuation, lang=tgt_lang)
    remove_non_native_characters_lang = partial(
        remove_non_native_characters, lang=tgt_lang
    )

    prefix = "must_c"
    suffix = "jsonl.gz"
    parts = ["dev", "tst-COMMON", "tst-HE", "train"]
    for p in parts:
        logging.info(f"Processing {p}")
        name = f"en-{tgt_lang}_{p}"

        # norm: normalization
        # rm: remove punctuation
        dst_name = manifest_dir / f"must_c_supervisions_{name}_norm_rm.jsonl.gz"
        if dst_name.is_file():
            logging.info(f"{dst_name} exists - skipping")
            continue

        manifests = read_manifests_if_cached(
            dataset_parts=name,
            output_dir=manifest_dir,
            prefix=prefix,
            suffix=suffix,
            types=("supervisions",),
        )
        if name not in manifests:
            raise RuntimeError(f"Processing {p} failed.")

        supervisions = manifests[name]["supervisions"]
        supervisions = supervisions.transform_text(normalize_punctuation_lang)
        supervisions = supervisions.transform_text(remove_punctuation)
        supervisions = supervisions.transform_text(lambda x: x.lower())
        supervisions = supervisions.transform_text(remove_non_native_characters_lang)
        supervisions = supervisions.transform_text(lambda x: re.sub(" +", " ", x))

        supervisions.to_file(dst_name)


def main():
    args = get_args()
    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
    logging.basicConfig(format=formatter, level=logging.INFO)

    logging.info(vars(args))
    assert args.manifest_dir.is_dir(), args.manifest_dir

    preprocess_must_c(
        manifest_dir=args.manifest_dir,
        tgt_lang=args.tgt_lang,
    )


if __name__ == "__main__":
    main()
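For reference, the supervision text pipeline above applies, in order: punctuation normalization, punctuation removal, lowercasing, non-native character removal, and whitespace squashing. A hedged end-to-end sketch of the same chain on a single string (not part of the commit):

    import re

    from normalize_punctuation import normalize_punctuation
    from remove_non_native_characters import remove_non_native_characters
    from remove_punctuation import remove_punctuation

    s = '„Schön!“  Ich heiße  Fangjun.'
    s = normalize_punctuation(s, lang="de")         # -> "Schön!" Ich heiße Fangjun.
    s = remove_punctuation(s)                       # -> Schön Ich heiße Fangjun
    s = s.lower()                                   # -> schön ich heiße fangjun
    s = remove_non_native_characters(s, lang="de")  # -> schoen ich heisse fangjun
    s = re.sub(" +", " ", s)
    print(s)  # schoen ich heisse fangjun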
21  egs/must_c/ST/local/remove_non_native_characters.py  Executable file
@@ -0,0 +1,21 @@
# Copyright    2023  Xiaomi Corp.        (authors: Fangjun Kuang)

import re


def remove_non_native_characters(s: str, lang: str):
    if lang == "de":
        # ä -> ae
        # ö -> oe
        # ü -> ue
        # ß -> ss

        s = re.sub("ä", "ae", s)
        s = re.sub("ö", "oe", s)
        s = re.sub("ü", "ue", s)
        s = re.sub("ß", "ss", s)
        # keep only a-z and spaces
        # note: ' is removed
        s = re.sub(r"[^a-z\s]", "", s)

    return s
41  egs/must_c/ST/local/remove_punctuation.py  Normal file
@@ -0,0 +1,41 @@
# Copyright    2023  Xiaomi Corp.        (authors: Fangjun Kuang)
import re
import string


def remove_punctuation(s: str) -> str:
    """
    It implements https://github.com/espnet/espnet/blob/master/utils/remove_punctuation.pl
    """

    # Remove punctuation except apostrophe
    # s/<space>/spacemark/g;  # for scoring
    s = re.sub("<space>", "spacemark", s)

    # s/'/apostrophe/g;
    s = re.sub("'", "apostrophe", s)

    # s/[[:punct:]]//g;
    s = s.translate(str.maketrans("", "", string.punctuation))
    # string.punctuation is the following string:
    # !"#$%&'()*+,-./:;<=>?@[\]^_`{|}~
    # See
    # https://stackoverflow.com/questions/265960/best-way-to-strip-punctuation-from-a-string

    # s/apostrophe/'/g;
    s = re.sub("apostrophe", "'", s)

    # s/spacemark/<space>/g;  # for scoring
    s = re.sub("spacemark", "<space>", s)

    # remove whitespace
    # s/\s+/ /g;
    s = re.sub(r"\s+", " ", s)

    # s/^\s+//;
    s = re.sub(r"^\s+", "", s)

    # s/\s+$//;
    s = re.sub(r"\s+$", "", s)

    return s
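A quick sketch (not part of the commit) showing that the apostrophe placeholder trick above preserves ' while stripping other punctuation:

    from remove_punctuation import remove_punctuation

    print(remove_punctuation("it's a test, isn't it?"))  # -> it's a test isn't it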
197  egs/must_c/ST/local/test_normalize_punctuation.py  Executable file
@@ -0,0 +1,197 @@
#!/usr/bin/env python3
# Copyright    2023  Xiaomi Corp.        (authors: Fangjun Kuang)

from normalize_punctuation import normalize_punctuation


def test_normalize_punctuation():
    # s/\r//g;
    s = "a\r\nb\r\n"
    n = normalize_punctuation(s, lang="en")
    assert "\r" not in n
    assert len(s) - 2 == len(n), (len(s), len(n))

    # s/\(/ \(/g;
    s = "(ab (c"
    n = normalize_punctuation(s, lang="en")
    assert n == " (ab (c", n

    # s/\)/\) /g;
    s = "a)b c)"
    n = normalize_punctuation(s, lang="en")
    assert n == "a) b c) "

    # s/ +/ /g;
    s = " a  b   c  d  "
    n = normalize_punctuation(s, lang="en")
    assert n == " a b c d "

    # s/\) ([\.\!\:\?\;\,])/\)$1/g;
    for i in ".!:?;,":
        s = f"a) {i}"
        n = normalize_punctuation(s, lang="en")
        assert n == f"a){i}"

    # s/\( /\(/g;
    s = "a( b"
    n = normalize_punctuation(s, lang="en")
    assert n == "a (b", n

    # s/ \)/\)/g;
    s = "ab ) a"
    n = normalize_punctuation(s, lang="en")
    assert n == "ab) a", n

    # s/(\d) \%/$1\%/g;
    s = "1 %a"
    n = normalize_punctuation(s, lang="en")
    assert n == "1%a", n

    # s/ :/:/g;
    s = "a :"
    n = normalize_punctuation(s, lang="en")
    assert n == "a:", n

    # s/ ;/;/g;
    s = "a ;"
    n = normalize_punctuation(s, lang="en")
    assert n == "a;", n

    # s/\`/\'/g;
    s = "`a`"
    n = normalize_punctuation(s, lang="en")
    assert n == "'a'", n

    # s/\'\'/ \" /g;
    s = "''a''"
    n = normalize_punctuation(s, lang="en")
    assert n == '"a"', n

    # s/„/\"/g;
    s = '„a"'
    n = normalize_punctuation(s, lang="en")
    assert n == '"a"', n

    # s/“/\"/g;
    s = "“a„"
    n = normalize_punctuation(s, lang="en")
    assert n == '"a"', n

    # s/”/\"/g;
    s = "“a”"
    n = normalize_punctuation(s, lang="en")
    assert n == '"a"', n

    # s/–/-/g;
    s = "a–b"
    n = normalize_punctuation(s, lang="en")
    assert n == "a-b", n

    # s/—/ - /g; s/ +/ /g;
    s = "a—b"
    n = normalize_punctuation(s, lang="en")
    assert n == "a - b", n

    # s/´/\'/g;
    s = "a´b"
    n = normalize_punctuation(s, lang="en")
    assert n == "a'b", n

    # s/([a-z])‘([a-z])/$1\'$2/gi;
    # s/([a-z])’([a-z])/$1\'$2/gi;
    for i in "‘’":
        s = f"a{i}B"
        n = normalize_punctuation(s, lang="en")
        assert n == "a'B", n

        s = f"A{i}B"
        n = normalize_punctuation(s, lang="en")
        assert n == "A'B", n

        s = f"A{i}b"
        n = normalize_punctuation(s, lang="en")
        assert n == "A'b", n

    # s/‘/\'/g;
    # s/‚/\'/g;
    for i in "‘‚":
        s = f"a{i}b"
        n = normalize_punctuation(s, lang="en")
        assert n == "a'b", n

    # s/’/\"/g;
    s = "’"
    n = normalize_punctuation(s, lang="en")
    assert n == '"', n

    # s/''/\"/g;
    s = "''"
    n = normalize_punctuation(s, lang="en")
    assert n == '"', n

    # s/´´/\"/g;
    s = "´´"
    n = normalize_punctuation(s, lang="en")
    assert n == '"', n

    # s/…/.../g;
    s = "…"
    n = normalize_punctuation(s, lang="en")
    assert n == "...", n

    # s/ « / \"/g;
    s = "a « b"
    n = normalize_punctuation(s, lang="en")
    assert n == 'a "b', n

    # s/« /\"/g;
    s = "a « b"
    n = normalize_punctuation(s, lang="en")
    assert n == 'a "b', n

    # s/«/\"/g;
    s = "a«b"
    n = normalize_punctuation(s, lang="en")
    assert n == 'a"b', n

    # s/ » /\" /g;
    s = " » "
    n = normalize_punctuation(s, lang="en")
    assert n == '" ', n

    # s/ »/\"/g;
    s = " »"
    n = normalize_punctuation(s, lang="en")
    assert n == '"', n

    # s/»/\"/g;
    s = "»"
    n = normalize_punctuation(s, lang="en")
    assert n == '"', n

    # s/ \%/\%/g;
    s = " %"
    n = normalize_punctuation(s, lang="en")
    assert n == "%", n

    # s/ :/:/g;
    s = " :"
    n = normalize_punctuation(s, lang="en")
    assert n == ":", n

    # s/(\d) (\d)/$1.$2/g;
    s = "2 3"
    n = normalize_punctuation(s, lang="en")
    assert n == "2.3", n

    # s/(\d) (\d)/$1,$2/g;
    s = "2 3"
    n = normalize_punctuation(s, lang="de")
    assert n == "2,3", n


def main():
    test_normalize_punctuation()


if __name__ == "__main__":
    main()
26  egs/must_c/ST/local/test_remove_non_native_characters.py  Executable file
@@ -0,0 +1,26 @@
#!/usr/bin/env python3
# Copyright    2023  Xiaomi Corp.        (authors: Fangjun Kuang)

from remove_non_native_characters import remove_non_native_characters


def test_remove_non_native_characters():
    s = "Ich heiße xxx好的01 fangjun".lower()
    n = remove_non_native_characters(s, lang="de")
    assert n == "ich heisse xxx fangjun", n

    s = "äÄ".lower()
    n = remove_non_native_characters(s, lang="de")
    assert n == "aeae", n

    s = "öÖ".lower()
    n = remove_non_native_characters(s, lang="de")
    assert n == "oeoe", n

    s = "üÜ".lower()
    n = remove_non_native_characters(s, lang="de")
    assert n == "ueue", n


if __name__ == "__main__":
    test_remove_non_native_characters()
17  egs/must_c/ST/local/test_remove_punctuation.py  Executable file
@@ -0,0 +1,17 @@
#!/usr/bin/env python3

from remove_punctuation import remove_punctuation


def test_remove_punctuation():
    s = "a,b'c!#"
    n = remove_punctuation(s)
    assert n == "ab'c", n

    s = " ab "  # remove leading and trailing spaces
    n = remove_punctuation(s)
    assert n == "ab", n


if __name__ == "__main__":
    test_remove_punctuation()
1  egs/must_c/ST/local/train_bpe_model.py  Symbolic link
@@ -0,0 +1 @@
../../../librispeech/ASR/local/train_bpe_model.py
1  egs/must_c/ST/local/validate_bpe_lexicon.py  Symbolic link
@@ -0,0 +1 @@
../../../librispeech/ASR/local/validate_bpe_lexicon.py
173  egs/must_c/ST/prepare.sh  Executable file
@@ -0,0 +1,173 @@
#!/usr/bin/env bash

# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674
export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python

set -eou pipefail

nj=10
stage=0
stop_stage=100

version=v1.0
tgt_lang=de
dl_dir=$PWD/download

must_c_dir=$dl_dir/must-c/$version/en-$tgt_lang/data

# We assume dl_dir (download dir) contains the following
# directories and files.
#
# - $dl_dir/must-c/$version/en-$tgt_lang/data/{dev,train,tst-COMMON,tst-HE}
#
#   Please go to https://ict.fbk.eu/must-c-releases/
#   to download and untar the dataset if you have not already done this.
#
# - $dl_dir/musan
#   This directory contains the following directories downloaded from
#   http://www.openslr.org/17/
#
#   - music
#   - noise
#   - speech

. shared/parse_options.sh || exit 1

# vocab size for sentence piece models.
# It will generate
#   data/lang_bpe_xxx/$version/$tgt_lang
#   data/lang_bpe_yyy/$version/$tgt_lang
# if the array contains xxx, yyy
vocab_sizes=(
  # 5000
  # 2000
  # 1000
  500
)

# All files generated by this script are saved in "data".
# You can safely remove "data" and rerun this script to regenerate it.
mkdir -p data

log() {
  # This function is from espnet
  local fname=${BASH_SOURCE[1]##*/}
  echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}

log "dl_dir: $dl_dir"

if [ ! -d $must_c_dir ]; then
  log "$must_c_dir does not exist"
  exit 1
fi

for d in dev train tst-COMMON tst-HE; do
  if [ ! -d $must_c_dir/$d ]; then
    log "$must_c_dir/$d does not exist!"
    exit 1
  fi
done

if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
  log "Stage 0: Download musan"
  if [ ! -d $dl_dir/musan ]; then
    lhotse download musan $dl_dir
  fi
fi

if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
  log "Stage 1: Prepare musan manifest"
  # We assume that you have downloaded the musan corpus
  # to $dl_dir/musan
  mkdir -p data/manifests
  if [ ! -e data/manifests/.musan.done ]; then
    lhotse prepare musan $dl_dir/musan data/manifests
    touch data/manifests/.musan.done
  fi
fi

if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
  log "Stage 2: Prepare must-c $version manifest for target language $tgt_lang"
  mkdir -p data/manifests/$version
  if [ ! -e data/manifests/$version/.${tgt_lang}.manifests.done ]; then
    lhotse prepare must-c \
      -j $nj \
      --tgt-lang $tgt_lang \
      $dl_dir/must-c/$version/ \
      data/manifests/$version/

    touch data/manifests/$version/.${tgt_lang}.manifests.done
  fi
fi

if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
  log "Stage 3: Text normalization for $version with target language $tgt_lang"
  if [ ! -f ./data/manifests/$version/.$tgt_lang.norm.done ]; then
    ./local/preprocess_must_c.py \
      --manifest-dir ./data/manifests/$version/ \
      --tgt-lang $tgt_lang
    touch ./data/manifests/$version/.$tgt_lang.norm.done
  fi
fi

if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
  log "Stage 4: Compute fbank for musan"
  mkdir -p data/fbank
  if [ ! -e data/fbank/.musan.done ]; then
    ./local/compute_fbank_musan.py
    touch data/fbank/.musan.done
  fi
fi

if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
  log "Stage 5: Compute fbank for $version with target language $tgt_lang"
  mkdir -p data/fbank/$version/
  if [ ! -e data/fbank/$version/.$tgt_lang.done ]; then
    ./local/compute_fbank_must_c.py \
      --in-dir ./data/manifests/$version/ \
      --out-dir ./data/fbank/$version/ \
      --tgt-lang $tgt_lang \
      --num-jobs $nj

    ./local/compute_fbank_must_c.py \
      --in-dir ./data/manifests/$version/ \
      --out-dir ./data/fbank/$version/ \
      --tgt-lang $tgt_lang \
      --num-jobs $nj \
      --perturb-speed 1

    touch data/fbank/$version/.$tgt_lang.done
  fi
fi

if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
  log "Stage 6: Prepare BPE based lang for $version with target language $tgt_lang"

  for vocab_size in ${vocab_sizes[@]}; do
    lang_dir=data/lang_bpe_${vocab_size}/$version/$tgt_lang/
    mkdir -p $lang_dir
    if [ ! -f $lang_dir/transcript_words.txt ]; then
      ./local/get_text.py ./data/fbank/$version/must_c_feats_en-${tgt_lang}_train.jsonl.gz > $lang_dir/transcript_words.txt
    fi

    if [ ! -f $lang_dir/words.txt ]; then
      ./local/get_words.py $lang_dir/transcript_words.txt > $lang_dir/words.txt
    fi

    if [ ! -f $lang_dir/bpe.model ]; then
      ./local/train_bpe_model.py \
        --lang-dir $lang_dir \
        --vocab-size $vocab_size \
        --transcript $lang_dir/transcript_words.txt
    fi

    if [ ! -f $lang_dir/L_disambig.pt ]; then
      ./local/prepare_lang_bpe.py --lang-dir $lang_dir

      log "Validating $lang_dir/lexicon.txt"
      ./local/validate_bpe_lexicon.py \
        --lexicon $lang_dir/lexicon.txt \
        --bpe-model $lang_dir/bpe.model
    fi
  done
fi
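Usage note: shared/parse_options.sh maps flags such as --stage, --stop-stage, and --tgt-lang onto the corresponding shell variables above, so a single stage can be rerun in isolation, e.g. ./prepare.sh --stage 5 --stop-stage 5 --tgt-lang de (a sketch; each stage's .done marker file must be removed first to force recomputation).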
1  egs/must_c/ST/shared  Symbolic link
@@ -0,0 +1 @@
../../../icefall/shared