mirror of https://github.com/k2-fsa/icefall.git (synced 2025-12-11 06:55:27 +00:00)

Add data preparation for the MuST-C corpus

commit 14c938aa07 (parent 1ce9a8b3c4)
@@ -107,7 +107,7 @@ fi
 if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
   log "Stage 2: Prepare musan manifest"
   # We assume that you have downloaded the musan corpus
-  # to data/musan
+  # to $dl_dir/musan
   mkdir -p data/manifests
   if [ ! -e data/manifests/.musan.done ]; then
     lhotse prepare musan $dl_dir/musan data/manifests
egs/must_c/ST/local/compute_fbank_musan.py (new symbolic link)
@@ -0,0 +1 @@
+../../../librispeech/ASR/local/compute_fbank_musan.py
egs/must_c/ST/local/compute_fbank_must_c.py (new executable file, 148 lines)
@@ -0,0 +1,148 @@
#!/usr/bin/env python3
# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang)

"""
This file computes fbank features of the MuST-C dataset.

It looks for manifests in the directory "in_dir" and writes
generated features to "out_dir".
"""

import argparse
import logging
from pathlib import Path

import torch
from lhotse import (
    CutSet,
    Fbank,
    FbankConfig,
    FeatureSet,
    LilcomChunkyWriter,
    load_manifest,
)

from icefall.utils import str2bool

# Torch's multithreaded behavior needs to be disabled or
# it wastes a lot of CPU and slows things down.
# Do this outside of main() in case it needs to take effect
# even when we are not invoking the main (e.g. when spawning subprocesses).
torch.set_num_threads(1)
torch.set_num_interop_threads(1)


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--in-dir",
        type=Path,
        required=True,
        help="Input manifest directory",
    )

    parser.add_argument(
        "--out-dir",
        type=Path,
        required=True,
        help="Output directory where generated fbank features are saved.",
    )

    parser.add_argument(
        "--tgt-lang",
        type=str,
        required=True,
        help="Target language, e.g., zh, de, fr.",
    )

    parser.add_argument(
        "--num-jobs",
        type=int,
        default=1,
        help="Number of jobs for computing features",
    )

    parser.add_argument(
        "--perturb-speed",
        type=str2bool,
        default=False,
        help="""Perturb speed with factor 0.9 and 1.1 on train subset.""",
    )

    return parser.parse_args()


def compute_fbank_must_c(
    in_dir: Path,
    out_dir: Path,
    tgt_lang: str,
    num_jobs: int,
    perturb_speed: bool,
):
    out_dir.mkdir(parents=True, exist_ok=True)

    extractor = Fbank(FbankConfig(num_mel_bins=80))

    parts = ["dev", "tst-COMMON", "tst-HE", "train"]

    prefix = "must_c"
    suffix = "jsonl.gz"
    for p in parts:
        logging.info(f"Processing {p}")

        cuts_path = f"{out_dir}/{prefix}_feats_en-{tgt_lang}_{p}"
        if perturb_speed and p == "train":
            cuts_path += "_sp"

        cuts_path += ".jsonl.gz"

        if Path(cuts_path).is_file():
            logging.info(f"{cuts_path} exists - skipping")
            continue

        recordings_filename = in_dir / f"{prefix}_recordings_en-{tgt_lang}_{p}.jsonl.gz"
        supervisions_filename = (
            in_dir / f"{prefix}_supervisions_en-{tgt_lang}_{p}_norm_rm.jsonl.gz"
        )
        assert recordings_filename.is_file(), recordings_filename
        assert supervisions_filename.is_file(), supervisions_filename

        cut_set = CutSet.from_manifests(
            recordings=load_manifest(recordings_filename),
            supervisions=load_manifest(supervisions_filename),
        )

        if perturb_speed and p == "train":
            logging.info("Speed perturbing for the train dataset")
            cut_set = cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
            storage_path = f"{out_dir}/{prefix}_feats_en-{tgt_lang}_{p}_sp"
        else:
            storage_path = f"{out_dir}/{prefix}_feats_en-{tgt_lang}_{p}"

        cut_set = cut_set.compute_and_store_features(
            extractor=extractor,
            storage_path=storage_path,
            num_jobs=num_jobs,
            storage_type=LilcomChunkyWriter,
        )

        logging.info(f"Saving to {cuts_path}")
        cut_set.to_file(cuts_path)
        logging.info(f"Saved to {cuts_path}")


def main():
    args = get_args()
    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
    logging.basicConfig(format=formatter, level=logging.INFO)

    logging.info(vars(args))
    assert args.in_dir.is_dir(), args.in_dir

    compute_fbank_must_c(
        in_dir=args.in_dir,
        out_dir=args.out_dir,
        tgt_lang=args.tgt_lang,
        num_jobs=args.num_jobs,
        perturb_speed=args.perturb_speed,
    )


if __name__ == "__main__":
    main()
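The script stores the compressed feature archives with LilcomChunkyWriter and writes one cuts manifest per split. A minimal smoke test (a sketch, not part of the commit; the path assumes version v1.0, target language de, and the dev split) that loads a generated manifest and reads one 80-dim fbank matrix:

from lhotse import load_manifest_lazy

# Hypothetical output path: data/fbank/<version>/ with tgt_lang=de, split=dev.
cuts = load_manifest_lazy("data/fbank/v1.0/must_c_feats_en-de_dev.jsonl.gz")
for cut in cuts:
    feats = cut.load_features()  # numpy array of shape (num_frames, 80)
    print(cut.id, cut.duration, feats.shape)
    break  # one cut is enough for a smoke test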
egs/must_c/ST/local/get_text.py (new executable file, 34 lines)
@@ -0,0 +1,34 @@
#!/usr/bin/env python3
# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang)
"""
This file prints the text field of supervisions from a cut set to the console.
"""

import argparse
from pathlib import Path

from lhotse import load_manifest_lazy


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "manifest",
        type=Path,
        help="Input manifest",
    )
    return parser.parse_args()


def main():
    args = get_args()
    assert args.manifest.is_file(), args.manifest

    cutset = load_manifest_lazy(args.manifest)
    for c in cutset:
        for sup in c.supervisions:
            print(sup.text)


if __name__ == "__main__":
    main()
egs/must_c/ST/local/get_words.py (new executable file, 48 lines)
@@ -0,0 +1,48 @@
#!/usr/bin/env python3
# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang)
"""
This file generates words.txt from the given transcript file.
"""

import argparse

from pathlib import Path


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "transcript",
        type=Path,
        help="Input transcript file",
    )
    return parser.parse_args()


def main():
    args = get_args()
    assert args.transcript.is_file(), args.transcript

    word_set = set()
    with open(args.transcript) as f:
        for line in f:
            words = line.strip().split()
            for w in words:
                word_set.add(w)

    # Note: reserved* should be kept in sync with ./local/prepare_lang_bpe.py
    reserved1 = ["<eps>", "!SIL", "<SPOKEN_NOISE>", "<UNK>"]
    reserved2 = ["#0", "<s>", "</s>"]

    for w in reserved1 + reserved2:
        assert w not in word_set, w

    words = sorted(list(word_set))
    words = reserved1 + words + reserved2

    for i, w in enumerate(words):
        print(w, i)


if __name__ == "__main__":
    main()
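Because reserved1 is prepended and reserved2 appended around the sorted vocabulary, the generated words.txt always maps <eps> to 0 and <UNK> to 3, and puts #0, <s>, </s> at the largest ids. A small sketch (not part of the commit; the words.txt path is a placeholder for whatever stage 6 below writes) that checks those invariants:

# Read a generated words.txt into a symbol->id map and verify the layout
# guaranteed by get_words.py. The path below is hypothetical.
sym2id = {}
with open("data/lang_bpe_500/v1.0/de/words.txt") as f:
    for line in f:
        sym, idx = line.split()
        sym2id[sym] = int(idx)

assert sym2id["<eps>"] == 0
assert sym2id["<UNK>"] == 3
assert sym2id["</s>"] == max(sym2id.values())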
egs/must_c/ST/local/prepare_lang.py (new symbolic link)
@@ -0,0 +1 @@
+../../../librispeech/ASR/local/prepare_lang.py

egs/must_c/ST/local/prepare_lang_bpe.py (new symbolic link)
@@ -0,0 +1 @@
+../../../librispeech/ASR/local/prepare_lang_bpe.py
egs/must_c/ST/local/preprocess_must_c.py
@@ -1,4 +1,5 @@
 #!/usr/bin/env python3
+# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang)
 """
 This script normalizes transcripts from supervisions.
 
@@ -11,11 +12,13 @@ Usage:
 import argparse
 import logging
 import re
-from pathlib import Path
 from functools import partial
+from pathlib import Path
 
-from normalize_punctuation import normalize_punctuation
 from lhotse.recipes.utils import read_manifests_if_cached
+from normalize_punctuation import normalize_punctuation
+from remove_non_native_characters import remove_non_native_characters
+from remove_punctuation import remove_punctuation
 
 
 def get_args():
@@ -39,6 +42,9 @@ def preprocess_must_c(manifest_dir: Path, tgt_lang: str):
     print(manifest_dir)
 
     normalize_punctuation_lang = partial(normalize_punctuation, lang=tgt_lang)
+    remove_non_native_characters_lang = partial(
+        remove_non_native_characters, lang=tgt_lang
+    )
 
     prefix = "must_c"
     suffix = "jsonl.gz"
@@ -66,7 +72,10 @@ def preprocess_must_c(manifest_dir: Path, tgt_lang: str):
 
         supervisions = manifests[name]["supervisions"]
         supervisions = supervisions.transform_text(normalize_punctuation_lang)
+        supervisions = supervisions.transform_text(remove_punctuation)
         supervisions = supervisions.transform_text(lambda x: x.lower())
+        supervisions = supervisions.transform_text(remove_non_native_characters_lang)
+        supervisions = supervisions.transform_text(lambda x: re.sub(" +", " ", x))
 
         supervisions.to_file(dst_name)
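The diff establishes a fixed normalization order: punctuation is normalized, then stripped, the text is lowercased, language-specific characters are transliterated or filtered, and runs of spaces are collapsed. A standalone sketch of that chain (not part of the commit; it assumes the sibling helper modules in ./local are importable, as they are when running from that directory):

import re
from functools import partial

from normalize_punctuation import normalize_punctuation
from remove_non_native_characters import remove_non_native_characters
from remove_punctuation import remove_punctuation


def normalize(text: str, tgt_lang: str = "de") -> str:
    # Same order as in preprocess_must_c.py.
    norm_punct = partial(normalize_punctuation, lang=tgt_lang)
    rm_chars = partial(remove_non_native_characters, lang=tgt_lang)
    text = norm_punct(text)
    text = remove_punctuation(text)
    text = text.lower()
    text = rm_chars(text)  # e.g. ä -> ae for German
    return re.sub(" +", " ", text)  # collapse runs of spaces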
egs/must_c/ST/local/remove_non_native_characters.py (new executable file, 21 lines)
@@ -0,0 +1,21 @@
# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang)

import re


def remove_non_native_characters(s: str, lang: str):
    if lang == "de":
        # ä -> ae
        # ö -> oe
        # ü -> ue
        # ß -> ss

        s = re.sub("ä", "ae", s)
        s = re.sub("ö", "oe", s)
        s = re.sub("ü", "ue", s)
        s = re.sub("ß", "ss", s)
        # keep only a-z and spaces
        # note: ' is removed
        s = re.sub(r"[^a-z\s]", "", s)

    return s
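Note that the final regex keeps only a-z and whitespace, so input must already be lowercase; preprocess_must_c.py lowercases before applying this transform. A small sketch of that contract (not part of the commit; the sample string is made up):

# The caller is expected to lowercase first, as the pipeline does; otherwise
# uppercase umlauts such as "Ä" would be dropped instead of mapped.
from remove_non_native_characters import remove_non_native_characters

s = "Straße 42!".lower()  # -> "straße 42!"
out = remove_non_native_characters(s, lang="de")
print(repr(out))  # 'strasse ' -- digits and "!" removed; the pipeline's
                  # later re.sub(" +", " ", ...) collapses leftover spaces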
@@ -1,4 +1,5 @@
 #!/usr/bin/env python3
+# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang)
 
 from normalize_punctuation import normalize_punctuation
 
egs/must_c/ST/local/test_remove_non_native_characters.py (new executable file, 26 lines)
@@ -0,0 +1,26 @@
#!/usr/bin/env python3
# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang)

from remove_non_native_characters import remove_non_native_characters


def test_remove_non_native_characters():
    s = "Ich heiße xxx好的01 fangjun".lower()
    n = remove_non_native_characters(s, lang="de")
    assert n == "ich heisse xxx fangjun", n

    s = "äÄ".lower()
    n = remove_non_native_characters(s, lang="de")
    assert n == "aeae", n

    s = "öÖ".lower()
    n = remove_non_native_characters(s, lang="de")
    assert n == "oeoe", n

    s = "üÜ".lower()
    n = remove_non_native_characters(s, lang="de")
    assert n == "ueue", n


if __name__ == "__main__":
    test_remove_non_native_characters()
egs/must_c/ST/local/train_bpe_model.py (new symbolic link)
@@ -0,0 +1 @@
+../../../librispeech/ASR/local/train_bpe_model.py

egs/must_c/ST/local/validate_bpe_lexicon.py (new symbolic link)
@@ -0,0 +1 @@
+../../../librispeech/ASR/local/validate_bpe_lexicon.py
egs/must_c/ST/prepare.sh
@@ -6,7 +6,7 @@ export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
 set -eou pipefail
 
 nj=10
-stage=-1
+stage=0
 stop_stage=100
 
 version=v1.0
@@ -101,8 +101,73 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
 fi
 
 if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
-  log "Stage 3: Text normalization"
+  log "Stage 3: Text normalization for $version with target language $tgt_lang"
+  if [ ! -f ./data/manifests/$version/.$tgt_lang.norm.done ]; then
   ./local/preprocess_must_c.py \
     --manifest-dir ./data/manifests/$version/ \
     --tgt-lang $tgt_lang
+    touch ./data/manifests/$version/.$tgt_lang.norm.done
+  fi
+fi
+
+if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
+  log "Stage 4: Compute fbank for musan"
+  mkdir -p data/fbank
+  if [ ! -e data/fbank/.musan.done ]; then
+    ./local/compute_fbank_musan.py
+    touch data/fbank/.musan.done
+  fi
+fi
+
+if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
+  log "Stage 5: Compute fbank for $version with target language $tgt_lang"
+  mkdir -p data/fbank/$version/
+  if [ ! -e data/fbank/$version/.$tgt_lang.done ]; then
+    ./local/compute_fbank_must_c.py \
+      --in-dir ./data/manifests/$version/ \
+      --out-dir ./data/fbank/$version/ \
+      --tgt-lang $tgt_lang \
+      --num-jobs $nj
+
+    ./local/compute_fbank_must_c.py \
+      --in-dir ./data/manifests/$version/ \
+      --out-dir ./data/fbank/$version/ \
+      --tgt-lang $tgt_lang \
+      --num-jobs $nj \
+      --perturb-speed 1
+
+    touch data/fbank/$version/.$tgt_lang.done
+  fi
+fi
+
+if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
+  log "Stage 6: Prepare BPE based lang for $version with target language $tgt_lang"
+
+  for vocab_size in ${vocab_sizes[@]}; do
+    lang_dir=data/lang_bpe_${vocab_size}/$version/$tgt_lang/
+    mkdir -p $lang_dir
+    if [ ! -f $lang_dir/transcript_words.txt ]; then
+      ./local/get_text.py ./data/fbank/$version/must_c_feats_en-${tgt_lang}_train.jsonl.gz > $lang_dir/transcript_words.txt
+    fi
+
+    if [ ! -f $lang_dir/words.txt ]; then
+      ./local/get_words.py $lang_dir/transcript_words.txt > $lang_dir/words.txt
+    fi
+
+    if [ ! -f $lang_dir/bpe.model ]; then
+      ./local/train_bpe_model.py \
+        --lang-dir $lang_dir \
+        --vocab-size $vocab_size \
+        --transcript $lang_dir/transcript_words.txt
+    fi
+
+    if [ ! -f $lang_dir/L_disambig.pt ]; then
+      ./local/prepare_lang_bpe.py --lang-dir $lang_dir
+
+      log "Validating $lang_dir/lexicon.txt"
+      ./local/validate_bpe_lexicon.py \
+        --lexicon $lang_dir/lexicon.txt \
+        --bpe-model $lang_dir/bpe.model
+    fi
+  done
 fi
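Stage 6 leaves, per vocabulary size, a transcript file, words.txt, a sentencepiece model, and the BPE lexicon under $lang_dir. A quick way to exercise the trained model (a sketch, not part of the commit; a vocab size of 500 and target language de are assumed placeholders):

# Load the trained BPE model and tokenize one normalized transcript line.
import sentencepiece as spm

sp = spm.SentencePieceProcessor()
sp.load("data/lang_bpe_500/v1.0/de/bpe.model")  # hypothetical path
print(sp.encode("ich heisse fangjun", out_type=str))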