Fix style issues.

This commit is contained in:
Fangjun Kuang 2021-11-17 12:24:35 +08:00
parent 469b665a5a
commit b29e4bdd03

View File

@@ -25,8 +25,8 @@ representation of a dict with the following format:
""" """
import argparse import argparse
import logging
from pathlib import Path from pathlib import Path
from typing import Dict, List, Tuple
import k2 import k2
import sentencepiece as spm import sentencepiece as spm
@@ -43,12 +43,14 @@ def get_args():
parser.add_argument( parser.add_argument(
"lm_data", "lm_data",
type=str, type=str,
help="""Input LM training data as text, e.g. data/downloads/lm/librispeech-lm-norm.txt""", help="""Input LM training data as text, e.g.
data/downloads/lm/librispeech-lm-norm.txt""",
) )
parser.add_argument( parser.add_argument(
"lm_archive", "lm_archive",
type=str, type=str,
help="""Path to output archive, e.g. lm_data.pt; look at the source of this script to see the format.""", help="""Path to output archive, e.g. lm_data.pt;
look at the source of this script to see the format.""",
) )
return parser.parse_args() return parser.parse_args()
@@ -57,6 +59,10 @@ def get_args():
def main(): def main():
args = get_args() args = get_args()
if Path(args.lm_archive).exists():
logging.warning(f"{args.lm_archive} exists - skipping")
return
sp = spm.SentencePieceProcessor() sp = spm.SentencePieceProcessor()
sp.load(args.bpe_model) sp.load(args.bpe_model)
@@ -76,7 +82,7 @@ def main():
break break
line_words = line.split() line_words = line.split()
for w in line_words: for w in line_words:
if not w in word2index: if w not in word2index:
w_bpe = sp.Encode(w) w_bpe = sp.Encode(w)
word2index[w] = len(words2bpe) word2index[w] = len(words2bpe)
words2bpe.append(w_bpe) words2bpe.append(w_bpe)
@@ -91,6 +97,12 @@ def main():
if __name__ == "__main__": if __name__ == "__main__":
formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO)
main() main()