From a704a2758b66360ccb931711c264c429cef4601d Mon Sep 17 00:00:00 2001
From: JinZr <60612200+JinZr@users.noreply.github.com>
Date: Thu, 20 Jul 2023 12:11:03 +0800
Subject: [PATCH] added scripts for BPE model training

---
 .../ASR/local/prepare_for_bpe_model.py       |  63 ++++++++++
 egs/multi_zh-hans/ASR/local/text2token.py    |   1 +
 .../ASR/local/train_bpe_model.py             | 109 ++++++++++++++++++
 egs/multi_zh-hans/ASR/prepare.sh             |  19 ++-
 4 files changed, 187 insertions(+), 5 deletions(-)
 create mode 100755 egs/multi_zh-hans/ASR/local/prepare_for_bpe_model.py
 create mode 120000 egs/multi_zh-hans/ASR/local/text2token.py
 create mode 100755 egs/multi_zh-hans/ASR/local/train_bpe_model.py

diff --git a/egs/multi_zh-hans/ASR/local/prepare_for_bpe_model.py b/egs/multi_zh-hans/ASR/local/prepare_for_bpe_model.py
new file mode 100755
index 000000000..1d6934b61
--- /dev/null
+++ b/egs/multi_zh-hans/ASR/local/prepare_for_bpe_model.py
@@ -0,0 +1,63 @@
+#!/usr/bin/env python3
+# Copyright 2021 Xiaomi Corp. (authors: Zengrui Jin)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+from pathlib import Path
+
+from tqdm.auto import tqdm
+from icefall.utils import tokenize_by_CJK_char
+
+
+def get_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--lang-dir",
+        type=str,
+        help="""Output directory.
+        The generated transcript_chars.txt is saved to this directory.
+        """,
+    )
+
+    parser.add_argument(
+        "--text",
+        type=str,
+        help="WenetSpeech training transcript.",
+    )
+
+    return parser.parse_args()
+
+
+def main():
+    args = get_args()
+    lang_dir = Path(args.lang_dir)
+    text = Path(args.text)
+
+    assert lang_dir.exists() and text.exists(), f"{lang_dir} or {text} does not exist!"
+
+    transcript_path = lang_dir / "transcript_chars.txt"
+
+    with open(text, "r", encoding="utf-8") as fin:
+        text_lines = fin.readlines()
+    tokenized_lines = []
+    for line in tqdm(text_lines, desc="Tokenizing training transcript"):
+        tokenized_lines.append(f"{tokenize_by_CJK_char(line)}\n")
+    with open(transcript_path, "w", encoding="utf-8") as fout:
+        fout.writelines(tokenized_lines)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/egs/multi_zh-hans/ASR/local/text2token.py b/egs/multi_zh-hans/ASR/local/text2token.py
new file mode 120000
index 000000000..ce5cfd537
--- /dev/null
+++ b/egs/multi_zh-hans/ASR/local/text2token.py
@@ -0,0 +1 @@
+../../../wenetspeech/ASR/local/text2token.py
\ No newline at end of file
diff --git a/egs/multi_zh-hans/ASR/local/train_bpe_model.py b/egs/multi_zh-hans/ASR/local/train_bpe_model.py
new file mode 100755
index 000000000..b651fc290
--- /dev/null
+++ b/egs/multi_zh-hans/ASR/local/train_bpe_model.py
@@ -0,0 +1,109 @@
+#!/usr/bin/env python3
+# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+# You can install sentencepiece via:
+#
+#   pip install sentencepiece
+#
+# Due to an issue reported in
+# https://github.com/google/sentencepiece/pull/642#issuecomment-857972030
+#
+# Please install a version >=0.1.96

+import argparse
+import shutil
+from pathlib import Path
+
+import sentencepiece as spm
+from icefall.utils import str2bool
+
+
+def get_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--lang-dir",
+        type=str,
+        help="""Input and output directory.
+        The generated bpe.model is saved to this directory.
+        """,
+    )
+
+    parser.add_argument(
+        "--transcript",
+        type=str,
+        help="Training transcript.",
+    )
+
+    parser.add_argument(
+        "--vocab-size",
+        type=int,
+        help="Vocabulary size for BPE training",
+    )
+
+    parser.add_argument(
+        "--byte-fallback",
+        type=str2bool,
+        default=True,
+        help="Enable byte fallback for BPE model.",
+    )
+
+    return parser.parse_args()
+
+
+def main():
+    args = get_args()
+    vocab_size = args.vocab_size
+    lang_dir = Path(args.lang_dir)
+
+    model_type = "unigram"
+
+    model_prefix = f"{lang_dir}/{model_type}_{vocab_size}"
+    train_text = args.transcript
+    character_coverage = 0.98
+    input_sentence_size = 100000000
+
+    user_defined_symbols = ["<blk>", "<sos/eos>"]
+    unk_id = len(user_defined_symbols)
+    # Note: unk_id is fixed to 2.
+    # If you change it, you should also change other
+    # places that are using it.
+
+    model_file = Path(model_prefix + ".model")
+    if not model_file.is_file():
+        spm.SentencePieceTrainer.train(
+            input=train_text,
+            vocab_size=vocab_size,
+            model_type=model_type,
+            model_prefix=model_prefix,
+            input_sentence_size=input_sentence_size,
+            character_coverage=character_coverage,
+            user_defined_symbols=user_defined_symbols,
+            unk_id=unk_id,
+            bos_id=-1,
+            eos_id=-1,
+            byte_fallback=args.byte_fallback,
+        )
+    else:
+        print(f"{model_file} exists - skipping")
+        return
+
+    shutil.copyfile(model_file, f"{lang_dir}/bpe.model")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/egs/multi_zh-hans/ASR/prepare.sh b/egs/multi_zh-hans/ASR/prepare.sh
index 26b34ea21..33addf425 100755
--- a/egs/multi_zh-hans/ASR/prepare.sh
+++ b/egs/multi_zh-hans/ASR/prepare.sh
@@ -15,9 +15,7 @@ dl_dir=$PWD/download
 . shared/parse_options.sh || exit 1
 
 vocab_sizes=(
-  # 2000
-  # 1000
-  500
+  2000
 )
 
 
@@ -185,7 +183,7 @@ if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then
 
   if [ ! -f data/manifests/.magicdata.done ]; then
     mkdir -p data/manifests
-    lhotse prepare magicdata -j $nj $dl_dir/magicdata data/manifests/magicdata
+    lhotse prepare magicdata $dl_dir/magicdata data/manifests/magicdata
     touch data/manifests/.magicdata.done
   fi
 
@@ -246,9 +244,20 @@ if [ $stage -le 11 ] && [ $stop_stage -ge 11 ]; then
     ln -svf $(realpath ../../../../wenetspeech/ASR/data/fbank/cuts_TEST_NET_raw.jsonl.gz) .
     cd ../..
   else
-    log "Abort! Please run ../../wenetspeech/ASR/prepare.sh --stage 5 --stop-stage 5"
+    log "Abort! Please run ../../wenetspeech/ASR/prepare.sh"
     exit 1
   fi
+
+  if [ -d ../../wenetspeech/ASR/data/lang_char/ ]; then
+    cd data
+    cp -r ../../../wenetspeech/ASR/data/lang_char .
+    cd ..
+  else
+    log "Abort! Please run ../../wenetspeech/ASR/prepare.sh"
+    exit 1
+  fi
+
+
 fi
 
 log "Dataset: KeSpeech"
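
After this patch, the intended flow is: prepare_for_bpe_model.py space-separates the training transcript by CJK character, and train_bpe_model.py fits a unigram sentencepiece model with byte fallback on the result (saved as bpe.model). Below is a minimal sanity-check sketch, not part of the patch: the path data/lang_bpe_2000/bpe.model is an assumption derived from the vocab size 2000 configured above (the actual lang dir depends on how prepare.sh invokes these scripts, which this excerpt does not show), and the sample strings are arbitrary.

#!/usr/bin/env python3
# Sanity-check sketch with assumed paths; not part of the patch.
import sentencepiece as spm

# Assumed output location of train_bpe_model.py; adjust to your lang dir.
sp = spm.SentencePieceProcessor(model_file="data/lang_bpe_2000/bpe.model")

# prepare_for_bpe_model.py tokenizes the transcript with tokenize_by_CJK_char,
# so feed the model text in the same space-separated form.
print(sp.encode("这 是 一 个 测 试", out_type=str))

# With byte_fallback=True, a character missed by the 0.98 character coverage
# decomposes into byte pieces such as <0xE9> instead of mapping to <unk>
# (unk_id is 2 here, after the user-defined <blk> and <sos/eos> symbols).
print(sp.encode("龘", out_type=str))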