fix sytle

This commit is contained in:
PingFeng Luo 2021-12-31 18:50:12 +08:00
parent 503275e649
commit 28d1e8660e
4 changed files with 32 additions and 56 deletions

View File

@ -1,43 +0,0 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corporation (Author: Pingfeng Luo)
import argparse
import re
from pathlib import Path
from typing import Dict, List
from pypinyin import pinyin, lazy_pinyin, Style
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument("--lexicon", type=str, help="The input lexicon file.")
return parser.parse_args()
def process_line(
line: str
) -> None:
"""
Args:
line:
A line of transcript consisting of space(s) separated words.
Returns:
Return None.
"""
char = line.strip().split()[0]
syllables = pinyin(char, style=Style.TONE3, heteronym=True)
syllables = ''.join(str(syllables[0][:]))
for s in syllables.split(',') :
print("{} {}".format(char, re.sub(r'[]', '', s)))
def main():
args = get_args()
assert Path(args.lexicon).is_file()
with open(args.lexicon) as f:
for line in f:
process_line(line=line)
if __name__ == "__main__":
main()

View File

@ -40,7 +40,8 @@ def get_parser():
parser.add_argument( parser.add_argument(
"--skip-ncols", "-s", default=0, type=int, help="skip first n columns" "--skip-ncols", "-s", default=0, type=int, help="skip first n columns"
) )
parser.add_argument("--space", default="<space>", type=str, help="space symbol") parser.add_argument("--space", default="<space>", type=str,
help="space symbol")
parser.add_argument( parser.add_argument(
"--non-lang-syms", "--non-lang-syms",
"-l", "-l",
@ -48,19 +49,15 @@ def get_parser():
type=str, type=str,
help="list of non-linguistic symobles, e.g., <NOISE> etc.", help="list of non-linguistic symobles, e.g., <NOISE> etc.",
) )
parser.add_argument("text", type=str, default=False, nargs="?", help="input text") parser.add_argument("text", type=str, default=False, nargs="?",
help="input text")
parser.add_argument( parser.add_argument(
"--trans_type", "--trans_type",
"-t", "-t",
type=str, type=str,
default="char", default="char",
choices=["char", "phn"], choices=["char", "phn"],
help="""Transcript type. char/phn. e.g., for TIMIT FADG0_SI1279 - help="""Transcript type. char/phn""",
If trans_type is char,
read from SI1279.WRD file -> "bricks are an alternative"
Else if trans_type is phn,
read from SI1279.PHN file -> "sil b r ih sil k s aa r er n aa l
sil t er n ih sil t ih v sil" """,
) )
return parser return parser
@ -78,7 +75,9 @@ def main():
if args.text: if args.text:
f = codecs.open(args.text, encoding="utf-8") f = codecs.open(args.text, encoding="utf-8")
else: else:
f = codecs.getreader("utf-8")(sys.stdin if is_python2 else sys.stdin.buffer) f = codecs.getreader("utf-8")(
sys.stdin if is_python2 else sys.stdin.buffer
)
sys.stdout = codecs.getwriter("utf-8")( sys.stdout = codecs.getwriter("utf-8")(
sys.stdout if is_python2 else sys.stdout.buffer sys.stdout if is_python2 else sys.stdout.buffer

View File

@ -40,6 +40,7 @@ from icefall.utils import (
setup_logger, setup_logger,
store_transcripts, store_transcripts,
write_error_stats, write_error_stats,
str2bool,
) )
@ -108,6 +109,16 @@ def get_parser():
default=3, default=3,
help="Maximum number of symbols per frame", help="Maximum number of symbols per frame",
) )
parser.add_argument(
"--export",
type=str2bool,
default=False,
help="""When enabled, the averaged model is saved to
conformer_ctc/exp/pretrained.pt. Note: only model.state_dict() is saved.
pretrained.pt contains a dict {"model": model.state_dict()},
which can be loaded by `icefall.checkpoint.load_checkpoint()`.
""",
)
return parser return parser
@ -417,6 +428,13 @@ def main():
model.to(device) model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device)) model.load_state_dict(average_checkpoints(filenames, device=device))
if params.export:
logging.info(f"Export averaged model to {params.exp_dir}/pretrained.pt")
torch.save(
{"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt"
)
return
model.to(device) model.to(device)
model.eval() model.eval()
model.device = device model.device = device
@ -427,7 +445,9 @@ def main():
wenetspeech = WenetSpeechDataModule(args) wenetspeech = WenetSpeechDataModule(args)
test_net_dl = wenetspeech.test_dataloaders(wenetspeech.test_net_cuts()) test_net_dl = wenetspeech.test_dataloaders(wenetspeech.test_net_cuts())
test_meetting_dl = wenetspeech.test_dataloaders(wenetspeech.test_meetting_cuts()) test_meetting_dl = wenetspeech.test_dataloaders(
wenetspeech.test_meetting_cuts()
)
test_sets = ["TEST_NET", "TEST_MEETTING"] test_sets = ["TEST_NET", "TEST_MEETTING"]
test_dls = [test_net_dl, test_meetting_dl] test_dls = [test_net_dl, test_meetting_dl]