Mirror of https://github.com/k2-fsa/icefall.git, synced 2025-09-09 17:14:20 +00:00
added scripts for BPE model training
commit a704a2758b (parent 48303ed667)
63 egs/multi_zh-hans/ASR/local/prepare_for_bpe_model.py (executable file)
@ -0,0 +1,63 @@
#!/usr/bin/env python3
# Copyright    2021  Xiaomi Corp.  (authors: Zengrui Jin)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import argparse
from pathlib import Path

from tqdm.auto import tqdm

from icefall.utils import tokenize_by_CJK_char


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--lang-dir",
        type=str,
        help="""Output directory.
        The generated transcript_chars.txt is saved to this directory.
        """,
    )

    parser.add_argument(
        "--text",
        type=str,
        help="WenetSpeech training transcript.",
    )

    return parser.parse_args()


def main():
    args = get_args()
    lang_dir = Path(args.lang_dir)
    text = Path(args.text)

    assert lang_dir.exists() and text.exists(), f"{lang_dir} or {text} does not exist!"

    transcript_path = lang_dir / "transcript_chars.txt"

    with open(text, "r", encoding="utf-8") as fin:
        text_lines = fin.readlines()

    tokenized_lines = []
    for line in tqdm(text_lines, desc="Tokenizing training transcript"):
        tokenized_lines.append(f"{tokenize_by_CJK_char(line)}\n")

    with open(transcript_path, "w+", encoding="utf-8") as fout:
        fout.writelines(tokenized_lines)


if __name__ == "__main__":
    main()
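For reference, tokenize_by_CJK_char (imported from icefall.utils above) splits every CJK character into its own whitespace-separated token while leaving non-CJK words such as English terms intact; that tokenized transcript is what the BPE trainer below consumes. A minimal sketch of the behavior, assuming a simplified ideograph range (the real implementation in icefall/utils.py covers more Unicode blocks):

import re


def tokenize_by_cjk_char_sketch(line: str) -> str:
    # Capture each CJK ideograph so re.split keeps it as its own token.
    # \u4e00-\u9fff is only the basic CJK Unified Ideographs block; the
    # real icefall helper matches a wider set of ranges.
    pattern = re.compile(r"([\u4e00-\u9fff])")
    chars = pattern.split(line.strip())
    return " ".join(w.strip() for w in chars if w.strip())


print(tokenize_by_cjk_char_sketch("你好 hello 世界"))
# -> 你 好 hello 世 界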
1 egs/multi_zh-hans/ASR/local/text2token.py (symbolic link)
@ -0,0 +1 @@
../../../wenetspeech/ASR/local/text2token.py
108 egs/multi_zh-hans/ASR/local/train_bpe_model.py (executable file)
@ -0,0 +1,108 @@
#!/usr/bin/env python3
# Copyright    2021  Xiaomi Corp.  (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


# You can install sentencepiece via:
#
#   pip install sentencepiece
#
# Due to an issue reported in
# https://github.com/google/sentencepiece/pull/642#issuecomment-857972030
#
# Please install a version >=0.1.96

import argparse
import shutil
from pathlib import Path

import sentencepiece as spm

from icefall.utils import str2bool


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--lang-dir",
        type=str,
        help="""Input and output directory.
        The generated bpe.model is saved to this directory.
        """,
    )

    parser.add_argument(
        "--transcript",
        type=str,
        help="Training transcript.",
    )

    parser.add_argument(
        "--vocab-size",
        type=int,
        help="Vocabulary size for BPE training",
    )

    parser.add_argument(
        "--byte-fallback",
        # Use str2bool instead of the builtin bool: argparse's bool()
        # treats any non-empty string, including "False", as True.
        type=str2bool,
        default=True,
        help="Enable byte fallback for BPE model.",
    )

    return parser.parse_args()


def main():
    args = get_args()
    vocab_size = args.vocab_size
    lang_dir = Path(args.lang_dir)

    model_type = "unigram"

    model_prefix = f"{lang_dir}/{model_type}_{vocab_size}"
    train_text = args.transcript
    character_coverage = 0.98
    input_sentence_size = 100000000

    user_defined_symbols = ["<blk>", "<sos/eos>"]
    unk_id = len(user_defined_symbols)
    # Note: unk_id is fixed to 2.
    # If you change it, you should also change other
    # places that are using it.

    model_file = Path(model_prefix + ".model")
    if not model_file.is_file():
        spm.SentencePieceTrainer.train(
            input=train_text,
            vocab_size=vocab_size,
            model_type=model_type,
            model_prefix=model_prefix,
            input_sentence_size=input_sentence_size,
            character_coverage=character_coverage,
            user_defined_symbols=user_defined_symbols,
            unk_id=unk_id,
            bos_id=-1,
            eos_id=-1,
            byte_fallback=args.byte_fallback,
        )
    else:
        print(f"{model_file} exists - skipping")
        return

    shutil.copyfile(model_file, f"{lang_dir}/bpe.model")


if __name__ == "__main__":
    main()
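As a quick sanity check of the trained model, something like the following can be used; the data/lang_bpe_2000 path is illustrative, not taken from this commit. With byte_fallback=True, characters absent from the training transcript decompose into byte-level pieces instead of collapsing to <unk>:

import sentencepiece as spm

sp = spm.SentencePieceProcessor()
sp.load("data/lang_bpe_2000/bpe.model")  # illustrative path

# <blk>=0 and <sos/eos>=1 are user_defined_symbols, so <unk> lands at id 2,
# matching the unk_id fixed in train_bpe_model.py above.
print(sp.unk_id())  # -> 2

# Byte fallback encodes unseen characters as byte pieces like <0xE4>
# rather than a single <unk> token.
print(sp.encode("你好", out_type=str))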
egs/multi_zh-hans/ASR/prepare.sh

@ -15,9 +15,7 @@ dl_dir=$PWD/download
 
 . shared/parse_options.sh || exit 1
 
 vocab_sizes=(
-  # 2000
-  # 1000
-  500
+  2000
 )
 
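For context, each entry in vocab_sizes drives one round of the two new scripts: tokenize the transcript by CJK character, then train a BPE model of that size. A minimal Python sketch of that chaining, assuming a data/lang_bpe_<size> layout and transcript location as in other icefall recipes (both are assumptions, not part of this commit):

import subprocess
from pathlib import Path

for vocab_size in [2000]:
    lang_dir = Path(f"data/lang_bpe_{vocab_size}")  # assumed layout
    lang_dir.mkdir(parents=True, exist_ok=True)

    # Step 1: tokenize the raw transcript by CJK character.
    subprocess.run(
        [
            "./local/prepare_for_bpe_model.py",
            "--lang-dir", str(lang_dir),
            "--text", "data/lang_char/text",  # hypothetical transcript path
        ],
        check=True,
    )

    # Step 2: train the BPE model on the tokenized transcript.
    subprocess.run(
        [
            "./local/train_bpe_model.py",
            "--lang-dir", str(lang_dir),
            "--vocab-size", str(vocab_size),
            "--transcript", str(lang_dir / "transcript_chars.txt"),
        ],
        check=True,
    )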
@ -185,7 +183,7 @@ if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then
 
   if [ ! -f data/manifests/.magicdata.done ]; then
     mkdir -p data/manifests
-    lhotse prepare magicdata -j $nj $dl_dir/magicdata data/manifests/magicdata
+    lhotse prepare magicdata $dl_dir/magicdata data/manifests/magicdata
     touch data/manifests/.magicdata.done
   fi
 
@ -246,9 +244,20 @@ if [ $stage -le 11 ] && [ $stop_stage -ge 11 ]; then
     ln -svf $(realpath ../../../../wenetspeech/ASR/data/fbank/cuts_TEST_NET_raw.jsonl.gz) .
     cd ../..
   else
-    log "Abort! Please run ../../wenetspeech/ASR/prepare.sh --stage 5 --stop-stage 5"
+    log "Abort! Please run ../../wenetspeech/ASR/prepare.sh"
     exit 1
   fi
+
+  if [ -d ../../wenetspeech/ASR/data/lang_char/ ]; then
+    cd data
+    cp -r ../../../../wenetspeech/ASR/data/lang_char .
+    cd ..
+  else
+    log "Abort! Please run ../../wenetspeech/ASR/prepare.sh"
+    exit 1
+  fi
+
+
 fi
 
 log "Dataset: KeSpeech"