From 594abc7975ebcb5e225c298567f7b6f1bf8e8df8 Mon Sep 17 00:00:00 2001
From: marcoyang
Date: Tue, 27 Feb 2024 16:05:58 +0800
Subject: [PATCH] add files

---
 egs/mls/ASR/local/train_bpe_model.py    | 114 ++++++++++++++++++++++++
 egs/mls/ASR/zipformer/asr_datamodule.py |   8 ++--
 2 files changed, 118 insertions(+), 4 deletions(-)
 create mode 100755 egs/mls/ASR/local/train_bpe_model.py

diff --git a/egs/mls/ASR/local/train_bpe_model.py b/egs/mls/ASR/local/train_bpe_model.py
new file mode 100755
index 000000000..55ff1d794
--- /dev/null
+++ b/egs/mls/ASR/local/train_bpe_model.py
@@ -0,0 +1,114 @@
+#!/usr/bin/env python3
+# Copyright    2021  Xiaomi Corp.        (authors: Fangjun Kuang)
+# Copyright    2024  Xiaomi Corp.        (authors: Xiaoyu Yang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+# You can install sentencepiece via:
+#
+#  pip install sentencepiece
+#
+# Due to an issue reported in
+# https://github.com/google/sentencepiece/pull/642#issuecomment-857972030
+#
+# Please install a version >=0.1.96
+
+import argparse
+import shutil
+from pathlib import Path
+
+import sentencepiece as spm
+
+
+def get_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--lang-dir",
+        type=str,
+        help="""Input and output directory.
+        The generated bpe.model is saved to this directory.
+        """,
+    )
+
+    parser.add_argument(
+        "--byte-fallback",
+        action="store_true",
+        help="""Whether to enable byte_fallback when training bpe.""",
+    )
+
+    parser.add_argument(
+        "--character-coverage",
+        type=float,
+        default=1.0,
+        help="Character coverage in vocabulary.",
+    )
+
+    parser.add_argument(
+        "--transcript",
+        type=str,
+        help="Training transcript.",
+    )
+
+    parser.add_argument(
+        "--vocab-size",
+        type=int,
+        help="Vocabulary size for BPE training",
+    )
+
+    return parser.parse_args()
+
+
+def main():
+    args = get_args()
+    vocab_size = args.vocab_size
+    lang_dir = Path(args.lang_dir)
+
+    model_type = "unigram"
+
+    model_prefix = f"{lang_dir}/{model_type}_{vocab_size}"
+    train_text = args.transcript
+    input_sentence_size = 100000000
+
+    user_defined_symbols = ["<blk>", "<sos/eos>"]
+    unk_id = len(user_defined_symbols)
+    # Note: unk_id is fixed to 2.
+    # If you change it, you should also change other
+    # places that are using it.
+
+    model_file = Path(model_prefix + ".model")
+    if not model_file.is_file():
+        spm.SentencePieceTrainer.train(
+            input=train_text,
+            vocab_size=vocab_size,
+            model_type=model_type,
+            model_prefix=model_prefix,
+            input_sentence_size=input_sentence_size,
+            character_coverage=args.character_coverage,
+            user_defined_symbols=user_defined_symbols,
+            byte_fallback=args.byte_fallback,
+            unk_id=unk_id,
+            bos_id=-1,
+            eos_id=-1,
+        )
+    else:
+        print(f"{model_file} exists - skipping")
+        return
+
+    shutil.copyfile(model_file, f"{lang_dir}/bpe.model")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/egs/mls/ASR/zipformer/asr_datamodule.py b/egs/mls/ASR/zipformer/asr_datamodule.py
index e5c266763..1e16e8077 100644
--- a/egs/mls/ASR/zipformer/asr_datamodule.py
+++ b/egs/mls/ASR/zipformer/asr_datamodule.py
@@ -86,8 +86,8 @@ class MLSAsrDataModule:
             "--language",
-            type=str2bool,
+            type=str,
             default="all",
-            choices=["english", "german", "dutch", "french", "spanish", "italian", "portuguese", "polish"],
-            help="""Used only when --mini-libri is False.When enabled,
-            Otherwise, use 100h subset.""",
+            choices=["english", "german", "dutch", "french", "spanish", "italian", "portuguese", "polish", "all"],
+            help="""Which language to train on. Use "all" to train on
+            all eight languages combined.""",
         )
         group.add_argument(