diff --git a/egs/librispeech/ASR/prepare.sh b/egs/librispeech/ASR/prepare.sh
index 8342d5212..8ce1eb478 100755
--- a/egs/librispeech/ASR/prepare.sh
+++ b/egs/librispeech/ASR/prepare.sh
@@ -107,7 +107,7 @@ fi
 if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
   log "Stage 2: Prepare musan manifest"
   # We assume that you have downloaded the musan corpus
-  # to data/musan
+  # to $dl_dir/musan
   mkdir -p data/manifests
   if [ ! -e data/manifests/.musan.done ]; then
     lhotse prepare musan $dl_dir/musan data/manifests
diff --git a/egs/must_c/ST/local/compute_fbank_musan.py b/egs/must_c/ST/local/compute_fbank_musan.py
new file mode 120000
index 000000000..5833f2484
--- /dev/null
+++ b/egs/must_c/ST/local/compute_fbank_musan.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/local/compute_fbank_musan.py
\ No newline at end of file
diff --git a/egs/must_c/ST/local/compute_fbank_must_c.py b/egs/must_c/ST/local/compute_fbank_must_c.py
new file mode 100755
index 000000000..0280d4a71
--- /dev/null
+++ b/egs/must_c/ST/local/compute_fbank_must_c.py
@@ -0,0 +1,148 @@
+#!/usr/bin/env python3
+# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang)
+
+"""
+This file computes fbank features of the MuST-C dataset.
+It looks for manifests in the directory "in_dir" and writes
+the generated features to "out_dir".
+"""
+import argparse
+import logging
+from pathlib import Path
+
+import torch
+from lhotse import (
+    CutSet,
+    Fbank,
+    FbankConfig,
+    FeatureSet,
+    LilcomChunkyWriter,
+    load_manifest,
+)
+
+from icefall.utils import str2bool
+
+# Torch's multithreaded behavior needs to be disabled, or
+# it wastes a lot of CPU and slows things down.
+# Do this outside of main() so that it takes effect
+# even when main() is not invoked (e.g., when spawning subprocesses).
+torch.set_num_threads(1)
+torch.set_num_interop_threads(1)
+
+
+def get_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--in-dir",
+        type=Path,
+        required=True,
+        help="Input manifest directory",
+    )
+
+    parser.add_argument(
+        "--out-dir",
+        type=Path,
+        required=True,
+        help="Output directory where generated fbank features are saved.",
+    )
+
+    parser.add_argument(
+        "--tgt-lang",
+        type=str,
+        required=True,
+        help="Target language, e.g., zh, de, fr.",
+    )
+
+    parser.add_argument(
+        "--num-jobs",
+        type=int,
+        default=1,
+        help="Number of jobs for computing features",
+    )
+
+    parser.add_argument(
+        "--perturb-speed",
+        type=str2bool,
+        default=False,
+        help="""Perturb speed with factors 0.9 and 1.1 on the train subset.""",
+    )
+
+    return parser.parse_args()
+
+
+def compute_fbank_must_c(
+    in_dir: Path,
+    out_dir: Path,
+    tgt_lang: str,
+    num_jobs: int,
+    perturb_speed: bool,
+):
+    out_dir.mkdir(parents=True, exist_ok=True)
+
+    extractor = Fbank(FbankConfig(num_mel_bins=80))
+
+    parts = ["dev", "tst-COMMON", "tst-HE", "train"]
+
+    prefix = "must_c"
+    suffix = "jsonl.gz"
+    for p in parts:
+        logging.info(f"Processing {p}")
+
+        cuts_path = f"{out_dir}/{prefix}_feats_en-{tgt_lang}_{p}"
+        if perturb_speed and p == "train":
+            cuts_path += "_sp"
+
+        cuts_path += f".{suffix}"
+
+        if Path(cuts_path).is_file():
+            logging.info(f"{cuts_path} exists - skipping")
+            continue
+
+        recordings_filename = in_dir / f"{prefix}_recordings_en-{tgt_lang}_{p}.{suffix}"
+        supervisions_filename = (
+            in_dir / f"{prefix}_supervisions_en-{tgt_lang}_{p}_norm_rm.{suffix}"
+        )
+        assert recordings_filename.is_file(), recordings_filename
+        assert supervisions_filename.is_file(), supervisions_filename
+        cut_set = CutSet.from_manifests(
+            recordings=load_manifest(recordings_filename),
+            supervisions=load_manifest(supervisions_filename),
+        )
+        if perturb_speed and p == "train":
+            logging.info("Speed perturbing for the train dataset")
+            cut_set = cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
+            storage_path = f"{out_dir}/{prefix}_feats_en-{tgt_lang}_{p}_sp"
+        else:
+            storage_path = f"{out_dir}/{prefix}_feats_en-{tgt_lang}_{p}"
+
+        cut_set = cut_set.compute_and_store_features(
+            extractor=extractor,
+            storage_path=storage_path,
+            num_jobs=num_jobs,
+            storage_type=LilcomChunkyWriter,
+        )
+
+        logging.info(f"Saving to {cuts_path}")
+        cut_set.to_file(cuts_path)
+        logging.info(f"Saved to {cuts_path}")
+
+
+def main():
+    args = get_args()
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    logging.basicConfig(format=formatter, level=logging.INFO)
+
+    logging.info(vars(args))
+    assert args.in_dir.is_dir(), args.in_dir
+
+    compute_fbank_must_c(
+        in_dir=args.in_dir,
+        out_dir=args.out_dir,
+        tgt_lang=args.tgt_lang,
+        num_jobs=args.num_jobs,
+        perturb_speed=args.perturb_speed,
+    )
+
+
+if __name__ == "__main__":
+    main()
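With --perturb-speed=1, the train cut set above becomes the union of the original cuts and two speed-perturbed copies (factors 0.9 and 1.1), i.e., three times the training data. A minimal sketch for sanity-checking the generated cuts, assuming version=v1.0 and tgt_lang=de from prepare.sh (the path is an assumption; adjust it to your setup):

    from lhotse import load_manifest_lazy

    # Hypothetical path; it depends on --out-dir, version, and --tgt-lang.
    cuts = load_manifest_lazy("data/fbank/v1.0/must_c_feats_en-de_dev.jsonl.gz")
    for cut in cuts:
        feats = cut.load_features()  # numpy array of shape (num_frames, 80)
        print(cut.id, feats.shape, cut.supervisions[0].text)
        break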
diff --git a/egs/must_c/ST/local/get_text.py b/egs/must_c/ST/local/get_text.py
new file mode 100755
index 000000000..558ab6de8
--- /dev/null
+++ b/egs/must_c/ST/local/get_text.py
@@ -0,0 +1,34 @@
+#!/usr/bin/env python3
+# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang)
+"""
+This file prints the text field of the supervisions from a cut set to the console.
+"""
+
+import argparse
+from pathlib import Path
+
+from lhotse import load_manifest_lazy
+
+
+def get_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "manifest",
+        type=Path,
+        help="Input manifest",
+    )
+    return parser.parse_args()
+
+
+def main():
+    args = get_args()
+    assert args.manifest.is_file(), args.manifest
+
+    cutset = load_manifest_lazy(args.manifest)
+    for c in cutset:
+        for sup in c.supervisions:
+            print(sup.text)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/egs/must_c/ST/local/get_words.py b/egs/must_c/ST/local/get_words.py
new file mode 100755
index 000000000..8e9e000a8
--- /dev/null
+++ b/egs/must_c/ST/local/get_words.py
@@ -0,0 +1,48 @@
+#!/usr/bin/env python3
+# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang)
+"""
+This file generates words.txt from the given transcript file.
+"""
+
+import argparse
+from pathlib import Path
+
+
+def get_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "transcript",
+        type=Path,
+        help="Input transcript file",
+    )
+    return parser.parse_args()
+
+
+def main():
+    args = get_args()
+    assert args.transcript.is_file(), args.transcript
+
+    word_set = set()
+    with open(args.transcript) as f:
+        for line in f:
+            words = line.strip().split()
+            for w in words:
+                word_set.add(w)
+
+    # Note: reserved1 and reserved2 should be kept in sync with
+    # ./local/prepare_lang_bpe.py
+    reserved1 = ["<eps>", "!SIL", "<SPOKEN_NOISE>", "<UNK>"]
+    reserved2 = ["#0", "<s>", "</s>"]
+
+    for w in reserved1 + reserved2:
+        assert w not in word_set, w
+
+    words = sorted(list(word_set))
+    words = reserved1 + words + reserved2
+
+    for i, w in enumerate(words):
+        print(w, i)
+
+
+if __name__ == "__main__":
+    main()
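With the reserved symbols above, <eps> is pinned to ID 0, and the disambiguation symbol #0 plus the sentence delimiters come after the real words. For intuition, running ./local/get_words.py on a transcript containing only the two words "hallo welt" prints:

    <eps> 0
    !SIL 1
    <SPOKEN_NOISE> 2
    <UNK> 3
    hallo 4
    welt 5
    #0 6
    <s> 7
    </s> 8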
+""" + +import argparse + +from pathlib import Path + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "transcript", + type=Path, + help="Input transcript file", + ) + return parser.parse_args() + + +def main(): + args = get_args() + assert args.transcript.is_file(), args.transcript + + word_set = set() + with open(args.transcript) as f: + for line in f: + words = line.strip().split() + for w in words: + word_set.add(w) + + # Note: reserved* should be keep in sync with ./local/prepare_lang_bpe.py + reserved1 = ["", "!SIL", "", ""] + reserved2 = ["#0", "", ""] + + for w in reserved1 + reserved2: + assert w not in word_set, w + + words = sorted(list(word_set)) + words = reserved1 + words + reserved2 + + for i, w in enumerate(words): + print(w, i) + + +if __name__ == "__main__": + main() diff --git a/egs/must_c/ST/local/prepare_lang.py b/egs/must_c/ST/local/prepare_lang.py new file mode 120000 index 000000000..747f2ab39 --- /dev/null +++ b/egs/must_c/ST/local/prepare_lang.py @@ -0,0 +1 @@ +../../../librispeech/ASR/local/prepare_lang.py \ No newline at end of file diff --git a/egs/must_c/ST/local/prepare_lang_bpe.py b/egs/must_c/ST/local/prepare_lang_bpe.py new file mode 120000 index 000000000..36b40e7fc --- /dev/null +++ b/egs/must_c/ST/local/prepare_lang_bpe.py @@ -0,0 +1 @@ +../../../librispeech/ASR/local/prepare_lang_bpe.py \ No newline at end of file diff --git a/egs/must_c/ST/local/preprocess_must_c.py b/egs/must_c/ST/local/preprocess_must_c.py index 10d0ba5c3..f18029406 100755 --- a/egs/must_c/ST/local/preprocess_must_c.py +++ b/egs/must_c/ST/local/preprocess_must_c.py @@ -1,4 +1,5 @@ #!/usr/bin/env python3 +# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang) """ This script normalizes transcripts from supervisions. @@ -11,11 +12,13 @@ Usage: import argparse import logging import re -from pathlib import Path from functools import partial +from pathlib import Path -from normalize_punctuation import normalize_punctuation from lhotse.recipes.utils import read_manifests_if_cached +from normalize_punctuation import normalize_punctuation +from remove_non_native_characters import remove_non_native_characters +from remove_punctuation import remove_punctuation def get_args(): @@ -39,6 +42,9 @@ def preprocess_must_c(manifest_dir: Path, tgt_lang: str): print(manifest_dir) normalize_punctuation_lang = partial(normalize_punctuation, lang=tgt_lang) + remove_non_native_characters_lang = partial( + remove_non_native_characters, lang=tgt_lang + ) prefix = "must_c" suffix = "jsonl.gz" @@ -66,7 +72,10 @@ def preprocess_must_c(manifest_dir: Path, tgt_lang: str): supervisions = manifests[name]["supervisions"] supervisions = supervisions.transform_text(normalize_punctuation_lang) + supervisions = supervisions.transform_text(remove_punctuation) supervisions = supervisions.transform_text(lambda x: x.lower()) + supervisions = supervisions.transform_text(remove_non_native_characters_lang) + supervisions = supervisions.transform_text(lambda x: re.sub(" +", " ", x)) supervisions.to_file(dst_name) diff --git a/egs/must_c/ST/local/remove_non_native_characters.py b/egs/must_c/ST/local/remove_non_native_characters.py new file mode 100755 index 000000000..f61fbd16b --- /dev/null +++ b/egs/must_c/ST/local/remove_non_native_characters.py @@ -0,0 +1,21 @@ +# Copyright 2023 Xiaomi Corp. 
diff --git a/egs/must_c/ST/local/remove_non_native_characters.py b/egs/must_c/ST/local/remove_non_native_characters.py
new file mode 100755
index 000000000..f61fbd16b
--- /dev/null
+++ b/egs/must_c/ST/local/remove_non_native_characters.py
@@ -0,0 +1,21 @@
+# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang)
+
+import re
+
+
+def remove_non_native_characters(s: str, lang: str) -> str:
+    if lang == "de":
+        # ä -> ae
+        # ö -> oe
+        # ü -> ue
+        # ß -> ss
+        s = re.sub("ä", "ae", s)
+        s = re.sub("ö", "oe", s)
+        s = re.sub("ü", "ue", s)
+        s = re.sub("ß", "ss", s)
+        # Keep only a-z and whitespace.
+        # Note: ' is also removed.
+        s = re.sub(r"[^a-z\s]", "", s)
+
+    return s
diff --git a/egs/must_c/ST/local/test_normalize_punctuation.py b/egs/must_c/ST/local/test_normalize_punctuation.py
index 28ecfb4f1..9079858c8 100755
--- a/egs/must_c/ST/local/test_normalize_punctuation.py
+++ b/egs/must_c/ST/local/test_normalize_punctuation.py
@@ -1,4 +1,5 @@
 #!/usr/bin/env python3
+# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang)
 
 from normalize_punctuation import normalize_punctuation
 
diff --git a/egs/must_c/ST/local/test_remove_non_native_characters.py b/egs/must_c/ST/local/test_remove_non_native_characters.py
new file mode 100755
index 000000000..e913fc967
--- /dev/null
+++ b/egs/must_c/ST/local/test_remove_non_native_characters.py
@@ -0,0 +1,26 @@
+#!/usr/bin/env python3
+# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang)
+
+from remove_non_native_characters import remove_non_native_characters
+
+
+def test_remove_non_native_characters():
+    s = "Ich heiße xxx好的01 fangjun".lower()
+    n = remove_non_native_characters(s, lang="de")
+    assert n == "ich heisse xxx fangjun", n
+
+    s = "äÄ".lower()
+    n = remove_non_native_characters(s, lang="de")
+    assert n == "aeae", n
+
+    s = "öÖ".lower()
+    n = remove_non_native_characters(s, lang="de")
+    assert n == "oeoe", n
+
+    s = "üÜ".lower()
+    n = remove_non_native_characters(s, lang="de")
+    assert n == "ueue", n
+
+
+if __name__ == "__main__":
+    test_remove_non_native_characters()
diff --git a/egs/must_c/ST/local/train_bpe_model.py b/egs/must_c/ST/local/train_bpe_model.py
new file mode 120000
index 000000000..6fad36421
--- /dev/null
+++ b/egs/must_c/ST/local/train_bpe_model.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/local/train_bpe_model.py
\ No newline at end of file
diff --git a/egs/must_c/ST/local/validate_bpe_lexicon.py b/egs/must_c/ST/local/validate_bpe_lexicon.py
new file mode 120000
index 000000000..721bb48e7
--- /dev/null
+++ b/egs/must_c/ST/local/validate_bpe_lexicon.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/local/validate_bpe_lexicon.py
\ No newline at end of file
diff --git a/egs/must_c/ST/prepare.sh b/egs/must_c/ST/prepare.sh
index 75a9aa27e..7f98655af 100755
--- a/egs/must_c/ST/prepare.sh
+++ b/egs/must_c/ST/prepare.sh
@@ -6,7 +6,7 @@ export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
 set -eou pipefail
 
 nj=10
-stage=-1
+stage=0
 stop_stage=100
 
 version=v1.0
@@ -101,8 +101,73 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
 fi
 
 if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
-  log "Stage 3: Text normalization"
-  ./local/preprocess_must_c.py \
-    --manifest-dir ./data/manifests/$version/ \
-    --tgt-lang $tgt_lang
+  log "Stage 3: Text normalization for $version with target language $tgt_lang"
+  if [ ! -f ./data/manifests/$version/.$tgt_lang.norm.done ]; then
+    ./local/preprocess_must_c.py \
+      --manifest-dir ./data/manifests/$version/ \
+      --tgt-lang $tgt_lang
+    touch ./data/manifests/$version/.$tgt_lang.norm.done
+  fi
+fi
+
+if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
+  log "Stage 4: Compute fbank for musan"
+  mkdir -p data/fbank
+  if [ ! -e data/fbank/.musan.done ]; then
+    ./local/compute_fbank_musan.py
+    touch data/fbank/.musan.done
+  fi
+fi
+
+if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
+  log "Stage 5: Compute fbank for $version with target language $tgt_lang"
+  mkdir -p data/fbank/$version/
+  if [ ! -e data/fbank/$version/.$tgt_lang.done ]; then
+    ./local/compute_fbank_must_c.py \
+      --in-dir ./data/manifests/$version/ \
+      --out-dir ./data/fbank/$version/ \
+      --tgt-lang $tgt_lang \
+      --num-jobs $nj
+
+    ./local/compute_fbank_must_c.py \
+      --in-dir ./data/manifests/$version/ \
+      --out-dir ./data/fbank/$version/ \
+      --tgt-lang $tgt_lang \
+      --num-jobs $nj \
+      --perturb-speed 1
+
+    touch data/fbank/$version/.$tgt_lang.done
+  fi
+fi
+
+if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
+  log "Stage 6: Prepare BPE-based lang for $version with target language $tgt_lang"
+
+  for vocab_size in ${vocab_sizes[@]}; do
+    lang_dir=data/lang_bpe_${vocab_size}/$version/$tgt_lang/
+    mkdir -p $lang_dir
+
+    if [ ! -f $lang_dir/transcript_words.txt ]; then
+      ./local/get_text.py ./data/fbank/$version/must_c_feats_en-${tgt_lang}_train.jsonl.gz > $lang_dir/transcript_words.txt
+    fi
+
+    if [ ! -f $lang_dir/words.txt ]; then
+      ./local/get_words.py $lang_dir/transcript_words.txt > $lang_dir/words.txt
+    fi
+
+    if [ ! -f $lang_dir/bpe.model ]; then
+      ./local/train_bpe_model.py \
+        --lang-dir $lang_dir \
+        --vocab-size $vocab_size \
+        --transcript $lang_dir/transcript_words.txt
+    fi
+
+    if [ ! -f $lang_dir/L_disambig.pt ]; then
+      ./local/prepare_lang_bpe.py --lang-dir $lang_dir
+
+      log "Validating $lang_dir/lexicon.txt"
+      ./local/validate_bpe_lexicon.py \
+        --lexicon $lang_dir/lexicon.txt \
+        --bpe-model $lang_dir/bpe.model
+    fi
+  done
 fi
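After Stage 6, one quick check of a trained BPE model is to load it with sentencepiece. The vocab size 500 and the en-de path below are assumptions; substitute an entry from the vocab_sizes array and your own version/target language:

    import sentencepiece as spm

    # Hypothetical path: vocab_size=500, version=v1.0, tgt_lang=de.
    sp = spm.SentencePieceProcessor()
    sp.load("data/lang_bpe_500/v1.0/de/bpe.model")

    print(sp.vocab_size())  # should equal the requested vocab size
    print(sp.encode("schoene gruesse welt", out_type=str))  # BPE pieces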