mirror of https://github.com/k2-fsa/icefall.git
Data prepare pipeline
This commit is contained in:
parent  e597be6867
commit  915c4e9d87
242  egs/libriheavy/ASR/local/compute_fbank_libriheavy.py  (Executable file)
@@ -0,0 +1,242 @@
#!/usr/bin/env python3
# Copyright    2023  Xiaomi Corp.        (authors: Xiaoyu Yang,
#                                                  Wei Kang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


"""
This file computes fbank features of the Libriheavy dataset.
It looks for manifests in the directory data/manifests.

The generated fbank features are saved in data/fbank.
"""

import argparse
import logging
import os
from pathlib import Path
from typing import Optional

import torch
from lhotse import (
    CutSet,
    Fbank,
    FbankConfig,
    KaldifeatFbank,
    KaldifeatFbankConfig,
    LilcomChunkyWriter,
)

from icefall.utils import get_executor, str2bool

# Torch's multithreaded behavior needs to be disabled, or
# it wastes a lot of CPU and slows things down.
# Do this outside of main() in case it needs to take effect
# even when we are not invoking the main (e.g. when spawning subprocesses).
torch.set_num_threads(1)
torch.set_num_interop_threads(1)


def get_args():
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--manifest-dir",
        type=str,
        help="""The source directory that contains raw manifests.""",
        default="data/manifests",
    )

    parser.add_argument(
        "--fbank-dir",
        type=str,
        help="""Fbank output dir.""",
        default="data/fbank",
    )

    parser.add_argument(
        "--subset",
        type=str,
        help="""Dataset part to compute fbank for. If None, all parts are used.""",
    )

    parser.add_argument(
        "--num-workers",
        type=int,
        default=20,
        help="Number of dataloading workers used for reading the audio.",
    )

    parser.add_argument(
        "--batch-duration",
        type=float,
        default=600.0,
        help="The maximum number of audio seconds in a batch. "
        "Determines batch size dynamically.",
    )

    parser.add_argument(
        "--perturb-speed",
        type=str2bool,
        default=False,
        help="Whether to use speed perturbation.",
    )

    parser.add_argument(
        "--use-splits",
        type=str2bool,
        default=False,
        help="Whether to compute fbank on splits.",
    )

    parser.add_argument(
        "--num-splits",
        type=int,
        help="""The number of splits of the medium and large subset.
        Only needed when --use-splits is true.""",
    )

    parser.add_argument(
        "--start",
        type=int,
        default=0,
        help="""Process pieces starting from this number (inclusive).
        Only needed when --use-splits is true.""",
    )

    parser.add_argument(
        "--stop",
        type=int,
        default=-1,
        help="""Stop processing pieces at this number (exclusive).
        Only needed when --use-splits is true.""",
    )

    return parser.parse_args()


def compute_fbank_libriheavy(args):
    src_dir = Path(args.manifest_dir)
    output_dir = Path(args.fbank_dir)
    num_jobs = min(15, os.cpu_count())
    num_mel_bins = 80
    subset = args.subset

    extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))

    with get_executor() as ex:  # Initialize the executor only once.
        output_cuts_path = output_dir / f"libriheavy_cuts_{subset}.jsonl.gz"
        if output_cuts_path.exists():
            logging.info(f"{output_cuts_path} exists - skipping")
            return

        input_cuts_path = src_dir / f"libriheavy_cuts_{subset}.jsonl.gz"
        assert input_cuts_path.exists(), f"{input_cuts_path} does not exist!"
        logging.info(f"Loading {input_cuts_path}")
        cut_set = CutSet.from_file(input_cuts_path)

        logging.info("Computing features")

        cut_set = cut_set.compute_and_store_features(
            extractor=extractor,
            storage_path=f"{output_dir}/libriheavy_feats_{subset}",
            # when an executor is specified, make more partitions
            num_jobs=num_jobs if ex is None else 80,
            executor=ex,
            storage_type=LilcomChunkyWriter,
        )

        logging.info(f"Saving to {output_cuts_path}")
        cut_set.to_file(output_cuts_path)


def compute_fbank_libriheavy_splits(args):
    num_splits = args.num_splits
    subset = args.subset
    src_dir = f"{args.manifest_dir}/libriheavy_{subset}_split"
    src_dir = Path(src_dir)
    output_dir = f"{args.fbank_dir}/libriheavy_{subset}_split"
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    start = args.start
    stop = args.stop
    if stop < start:
        stop = num_splits

    stop = min(stop, num_splits)

    device = torch.device("cpu")
    if torch.cuda.is_available():
        device = torch.device("cuda", 0)
    extractor = KaldifeatFbank(KaldifeatFbankConfig(device=device))
    logging.info(f"device: {device}")

    num_digits = 8  # num_digits is fixed by lhotse split-lazy
    for i in range(start, stop):
        idx = f"{i + 1}".zfill(num_digits)
        logging.info(f"Processing {idx}/{num_splits}")

        cuts_path = output_dir / f"libriheavy_cuts_{subset}.{idx}.jsonl.gz"
        if cuts_path.is_file():
            logging.info(f"{cuts_path} exists - skipping")
            continue

        raw_cuts_path = src_dir / f"libriheavy_cuts_{subset}.{idx}.jsonl.gz"
        if not raw_cuts_path.is_file():
            logging.info(f"{raw_cuts_path} does not exist - skipping it")
            continue

        logging.info(f"Loading {raw_cuts_path}")
        cut_set = CutSet.from_file(raw_cuts_path)

        logging.info("Computing features")
        if (output_dir / f"libriheavy_feats_{subset}_{idx}.lca").exists():
            logging.info(f"Removing {output_dir}/libriheavy_feats_{subset}_{idx}.lca")
            os.remove(output_dir / f"libriheavy_feats_{subset}_{idx}.lca")

        cut_set = cut_set.compute_and_store_features_batch(
            extractor=extractor,
            storage_path=f"{output_dir}/libriheavy_feats_{subset}_{idx}",
            num_workers=args.num_workers,
            batch_duration=args.batch_duration,
            overwrite=True,
        )

        logging.info("About to split cuts into smaller chunks.")
        cut_set = cut_set.trim_to_supervisions(
            keep_overlapping=False, min_duration=None
        )

        logging.info(f"Saving to {cuts_path}")
        cut_set.to_file(cuts_path)
        logging.info(f"Saved to {cuts_path}")


if __name__ == "__main__":
    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"

    logging.basicConfig(format=formatter, level=logging.INFO)
    args = get_args()
    logging.info(vars(args))

    if args.use_splits:
        assert args.num_splits is not None, "Please provide num_splits"
        compute_fbank_libriheavy_splits(args)
    else:
        compute_fbank_libriheavy(args)
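A quick way to sanity-check the output of either code path is to load the generated manifest with lhotse. This is only an illustrative sketch; the subset name and paths below are examples, not values fixed by the script:

from lhotse import CutSet

# Example paths; substitute the subset and fbank dir you actually generated.
cuts = CutSet.from_file("data/fbank/libriheavy_cuts_small.jsonl.gz")
cut = next(iter(cuts))
print(cut.id, cut.duration, cut.num_frames)

# Each cut now references stored fbank features (80-dim, per the config above).
feats = cut.load_features()
print(feats.shape)  # (num_frames, 80)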
1  egs/libriheavy/ASR/local/compute_fbank_musan.py  (Symbolic link)
@@ -0,0 +1 @@
../../../librispeech/ASR/local/compute_fbank_musan.py
38  egs/libriheavy/ASR/local/prepare_manifest.py  (Executable file)
@@ -0,0 +1,38 @@
#!/usr/bin/env python3
# Copyright    2023  Xiaomi Corp.        (authors: Wei Kang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import gzip
import json
import sys
from pathlib import Path


# Assign text of the supervisions and remove unnecessary entries.
def main():
    assert len(sys.argv) == 3, "Usage: ./local/prepare_manifest.py INPUT OUTPUT_DIR"
    fname = Path(sys.argv[1]).name
    oname = Path(sys.argv[2]) / fname
    with gzip.open(sys.argv[1], "r") as fin, gzip.open(oname, "w") as fout:
        for line in fin:
            cut = json.loads(line)
            cut["supervisions"][0]["text"] = cut["supervisions"][0]["custom"]["texts"][0]
            del cut["supervisions"][0]["custom"]
            del cut["custom"]
            fout.write((json.dumps(cut) + "\n").encode())


if __name__ == "__main__":
    main()
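The effect of prepare_manifest.py on a single manifest line can be pictured with a small, made-up cut (real Libriheavy cuts carry more fields): the first entry of the custom texts becomes the supervision text, and the custom blocks are dropped.

import json

# A hypothetical, heavily trimmed cut as it might appear in the raw manifest.
cut = {
    "id": "cut-0001",
    "supervisions": [
        {
            "id": "sup-0001",
            "text": "",
            "custom": {"texts": ["Hello, world!", "HELLO WORLD"]},
        }
    ],
    "custom": {"other": "metadata"},
}

# The same three operations the script applies to every line:
cut["supervisions"][0]["text"] = cut["supervisions"][0]["custom"]["texts"][0]
del cut["supervisions"][0]["custom"]
del cut["custom"]

print(json.dumps(cut))
# {"id": "cut-0001", "supervisions": [{"id": "sup-0001", "text": "Hello, world!"}]}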
73  egs/libriheavy/ASR/local/prepare_text.py  (Executable file)
@@ -0,0 +1,73 @@
#!/usr/bin/env python3
# Copyright    2023  Xiaomi Corp.        (authors: Wei Kang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import codecs
import sys


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--text",
        type=str,
        help="""Path to the input text. If not given, the text is read from stdin.""",
    )
    parser.add_argument(
        "--normalize",
        action="store_true",
        help="""Whether to normalize the text.
        If given, the text is upper-cased and all punctuation is removed.
        """,
    )
    return parser.parse_args()


def simple_cleanup(text: str) -> str:
    table = str.maketrans("’‘,。;?!():-《》、“”【】", "'',.;?!(): <>/\"\"[]")
    text = text.translate(table)
    return text.strip()


def remove_punc_to_upper(text: str) -> str:
    text = text.replace("‘", "'")
    text = text.replace("’", "'")
    tokens = set("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789'")
    s_list = [x.upper() if x in tokens else " " for x in text]
    s = " ".join("".join(s_list).split()).strip()
    return s


def main():
    args = get_args()
    if args.text:
        f = codecs.open(args.text, encoding="utf-8")
    else:
        f = codecs.getreader("utf-8")(sys.stdin.buffer)

    sys.stdout = codecs.getwriter("utf-8")(sys.stdout.buffer)
    line = f.readline()
    while line:
        if args.normalize:
            print(remove_punc_to_upper(line))
        else:
            print(simple_cleanup(line))
        line = f.readline()


if __name__ == "__main__":
    main()
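For illustration, here is what the two helpers above do to a sample line (the import path is hypothetical; the functions live in local/prepare_text.py):

from prepare_text import remove_punc_to_upper, simple_cleanup  # hypothetical import

s = "“Hello, world!” she said."

# simple_cleanup only maps full-width punctuation to ASCII equivalents and
# strips surrounding whitespace, so casing and punctuation are kept:
print(simple_cleanup(s))        # "Hello, world!" she said.

# remove_punc_to_upper keeps letters, digits and apostrophes, upper-cases them,
# and collapses everything else into single spaces:
print(remove_punc_to_upper(s))  # HELLO WORLD SHE SAID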
local/train_bpe_model.py
@@ -42,6 +42,19 @@ def get_args():
         """,
     )
 
+    parser.add_argument(
+        "--byte-fallback",
+        action="store_true",
+        help="""Whether to enable byte_fallback when training bpe.""",
+    )
+
+    parser.add_argument(
+        "--character-coverage",
+        type=float,
+        default=1.0,
+        help="Character coverage in vocabulary.",
+    )
+
     parser.add_argument(
         "--transcript",
         type=str,
@@ -66,7 +79,6 @@ def main():
 
     model_prefix = f"{lang_dir}/{model_type}_{vocab_size}"
     train_text = args.transcript
-    character_coverage = 1.0
     input_sentence_size = 100000000
 
     user_defined_symbols = ["<blk>", "<sos/eos>"]
@@ -83,8 +95,9 @@ def main():
         model_type=model_type,
         model_prefix=model_prefix,
         input_sentence_size=input_sentence_size,
-        character_coverage=character_coverage,
+        character_coverage=args.character_coverage,
         user_defined_symbols=user_defined_symbols,
+        byte_fallback=args.byte_fallback,
         unk_id=unk_id,
         bos_id=-1,
         eos_id=-1,
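The two new options are passed straight through to sentencepiece. A minimal standalone sketch of such a training call (the file names, vocab size and unk_id below are assumptions for illustration, not values taken from the script):

import sentencepiece as spm

spm.SentencePieceTrainer.train(
    input="data/lang_punc_bpe_756/text",        # one sentence per line (assumed path)
    model_prefix="data/lang_punc_bpe_756/bpe",  # assumed prefix
    model_type="bpe",
    vocab_size=756,
    # byte_fallback splits characters unseen in training into UTF-8 bytes instead
    # of mapping them to <unk>; prepare.sh adds 256 to the vocab size to make room
    # for those byte pieces.
    byte_fallback=True,
    # A coverage below 1.0 lets the rarest characters fall back to bytes.
    character_coverage=0.99,
    user_defined_symbols=["<blk>", "<sos/eos>"],
    unk_id=2,  # assumed; placed after the user-defined symbols
    bos_id=-1,
    eos_id=-1,
)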
egs/libriheavy/ASR/prepare.sh
@@ -8,6 +8,7 @@ set -eou pipefail
 nj=15
 stage=-1
 stop_stage=100
+export CUDA_VISIBLE_DEVICES=""
 
 # We assume dl_dir (download dir) contains the following
 # directories and files. If not, they will be downloaded
@@ -43,6 +44,8 @@ vocab_sizes=(
 # All files generated by this script are saved in "data".
 # You can safely remove "data" and rerun this script to regenerate it.
 mkdir -p data
+fbank_dir=data/fbank
+manifests_dir=data/manifests
 
 log() {
   # This function is from espnet
@@ -80,7 +83,7 @@ if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
   # ln -sfv /path/to/libriheavy $dl_dir/libriheavy
   #
   mkdir -p $dl_dir/libriheavy
-  for subset in small medium large dev test_clean test_other test_clean_large test_other_large; do
+  for subset in small medium large dev test_clean test_other; do
     if [ ! -e $dl_dir/libriheavy/libriheavy_cuts_${subset}.jsonl.gz ]; then
       log "Downloading ${subset} subset."
       wget -P $dl_dir/libriheavy -c https://huggingface.co/datasets/pkufool/libriheavy/resolve/main/libriheavy_cuts_${subset}.jsonl.gz
@@ -118,18 +121,147 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
   log "Stage 2: Prepare musan manifest"
   # We assume that you have downloaded the musan corpus
   # to $dl_dir/musan
-  mkdir -p data/manifests
-  if [ ! -e data/manifests/.musan.done ]; then
-    lhotse prepare musan $dl_dir/musan data/manifests
-    touch data/manifests/.musan.done
+  mkdir -p $manifests_dir
+  if [ ! -e $manifests_dir/.musan.done ]; then
+    lhotse prepare musan $dl_dir/musan $manifests_dir
+    touch $manifests_dir/.musan.done
   fi
 fi
 
+if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
+  log "Stage 3: Prepare Libriheavy manifests"
+  mkdir -p $manifests_dir
+  for subset in small medium large dev test_clean test_other; do
+    if [ ! -e $manifests_dir/libriheavy_cuts_${subset}.jsonl.gz ]; then
+      log "Prepare manifest for subset : ${subset}"
+      ./local/prepare_manifest.py $dl_dir/libriheavy/libriheavy_cuts_${subset}.jsonl.gz $manifests_dir
+    fi
+  done
+fi
+
 if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
   log "Stage 4: Compute fbank for musan"
-  mkdir -p data/fbank
-  if [ ! -e data/fbank/.musan.done ]; then
+  mkdir -p $fbank_dir
+  if [ ! -e $fbank_dir/.musan.done ]; then
     ./local/compute_fbank_musan.py
-    touch data/fbank/.musan.done
+    touch $fbank_dir/.musan.done
   fi
 fi
+
+if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
+  log "Stage 5: Compute fbank for small subset and validation subsets"
+  for subset in test_clean test_other dev small; do
+    log "Computing $subset subset."
+    if [ ! -e $fbank_dir/.libriheavy.${subset}.done ]; then
+      ./local/compute_fbank_libriheavy.py \
+        --manifest-dir ${manifests_dir} \
+        --subset ${subset} \
+        --fbank-dir $fbank_dir \
+        --num-workers $nj
+    fi
+  done
+fi
+
+num_per_split=8000
+if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
+  log "Stage 6: Split medium and large subsets."
+  for subset in medium large; do
+    log "Splitting subset : $subset"
+    split_dir=$manifests_dir/libriheavy_${subset}_split
+    mkdir -p $split_dir
+    if [ ! -e $split_dir/.split_completed ]; then
+      lhotse split-lazy $manifests_dir/libriheavy_cuts_${subset}.jsonl.gz $split_dir $num_per_split
+      touch $split_dir/.split_completed
+    fi
+  done
+fi
+
+if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
+  log "Stage 7: Compute fbank for medium and large subsets"
+  mkdir -p $fbank_dir
+  chunk_size=20
+  for subset in medium large; do
+    if [ $subset == "large" ]; then
+      chunk_size=200
+    fi
+    num_splits=$(find $manifests_dir/libriheavy_${subset}_split -name "libriheavy_cuts_${subset}.*.jsonl.gz" | wc -l)
+    if [ ! -e $fbank_dir/.libriheavy.${subset}.done ]; then
+      for i in $(seq 0 1 6); do
+        start=$(( i * $chunk_size ))
+        end=$(( (i+1) * $chunk_size ))
+        ./local/compute_fbank_libriheavy.py \
+          --manifest-dir ${manifests_dir} \
+          --use-splits 1 \
+          --subset ${subset} \
+          --fbank-dir $fbank_dir \
+          --num-splits $num_splits \
+          --num-workers $nj \
+          --start $start \
+          --stop $end &
+      done
+      wait
+      touch $fbank_dir/.libriheavy.${subset}.done
+    fi
+  done
+fi
+
+if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then
+  log "Stage 8: Combine features for medium and large subsets."
+  for subset in medium large; do
+    if [ ! -f $fbank_dir/libriheavy_cuts_${subset}.jsonl.gz ]; then
+      pieces=$(find $fbank_dir/libriheavy_${subset}_split -name "libriheavy_cuts_${subset}.*.jsonl.gz")
+      lhotse combine $pieces $fbank_dir/libriheavy_cuts_${subset}.jsonl.gz
+    fi
+  done
+fi
+
+if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then
+  log "Stage 9: Train BPE model for normalized text"
+
+  if [ ! -f data/texts ]; then
+    gunzip -c $manifests_dir/libriheavy_cuts_medium.jsonl.gz \
+      | jq '.supervisions[].text' | sed 's/"//;s/\\//g;s/"$//' \
+      | ./local/prepare_text.py --normalize > data/texts
+  fi
+
+  for vocab_size in ${vocab_sizes[@]}; do
+    lang_dir=data/lang_bpe_${vocab_size}
+    mkdir -p $lang_dir
+
+    cp data/texts $lang_dir/text
+
+    if [ ! -f $lang_dir/bpe.model ]; then
+      ./local/train_bpe_model.py \
+        --lang-dir $lang_dir \
+        --vocab-size $vocab_size \
+        --transcript $lang_dir/text
+    fi
+  done
+fi
+
+
+if [ $stage -le 10 ] && [ $stop_stage -ge 10 ]; then
+  log "Stage 10: Train BPE model for unnormalized text"
+  if [ ! -f data/punc_texts ]; then
+    gunzip -c $manifests_dir/libriheavy_cuts_medium.jsonl.gz \
+      | jq '.supervisions[].text' | sed 's/"//;s/\\//g;s/"$//' \
+      | ./local/prepare_text.py > data/punc_texts
+  fi
+  for vocab_size in ${vocab_sizes[@]}; do
+    new_vocab_size=$(($vocab_size + 256))
+    lang_dir=data/lang_punc_bpe_${new_vocab_size}
+    mkdir -p $lang_dir
+
+    cp data/punc_texts $lang_dir/text
+
+    if [ ! -f $lang_dir/bpe.model ]; then
+      ./local/train_bpe_model.py \
+        --lang-dir $lang_dir \
+        --vocab-size ${new_vocab_size} \
+        --byte-fallback \
+        --character-coverage 0.99 \
+        --transcript $lang_dir/text
+    fi
+  done
+fi
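Stages 9 and 10 obtain the BPE training text by piping the medium-subset manifest through jq, sed and local/prepare_text.py. A rough Python equivalent of the normalized path, only to illustrate what ends up in data/texts (it reimplements the normalization inline rather than importing the script; paths are the recipe defaults):

import gzip
import json


def to_upper_no_punc(text: str) -> str:
    # Same idea as remove_punc_to_upper in local/prepare_text.py.
    keep = set("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789'")
    return " ".join("".join(c.upper() if c in keep else " " for c in text).split())


with gzip.open("data/manifests/libriheavy_cuts_medium.jsonl.gz", "rt") as fin, open(
    "data/texts", "w", encoding="utf-8"
) as fout:
    for line in fin:
        cut = json.loads(line)
        for sup in cut["supervisions"]:
            fout.write(to_upper_no_punc(sup["text"]) + "\n")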