mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-27 02:34:21 +00:00
fix style
This commit is contained in:
parent
503275e649
commit
28d1e8660e
@ -1,43 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
# Copyright 2021 Xiaomi Corporation (Author: Pingfeng Luo)
|
||||
import argparse
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import Dict, List
|
||||
from pypinyin import pinyin, lazy_pinyin, Style
|
||||
|
||||
def get_args():
    """Parse the command-line options of this script.

    Returns:
      The argparse namespace. Its ``lexicon`` attribute holds the value
      passed to ``--lexicon`` (``None`` when the option is omitted).
    """
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--lexicon",
        type=str,
        help="The input lexicon file.",
    )
    parsed = cli.parse_args()
    return parsed
|
||||
|
||||
|
||||
def process_line(line: str) -> None:
    """Print one ``<word> <syllable>`` entry per candidate reading.

    Args:
      line:
        A line of the lexicon consisting of space(s) separated fields.
        Only the first field (the word) is used.

    Returns:
      Return None. The entries are written to stdout, one per line.
    """
    fields = line.split()
    if not fields:
        # A blank line has no word to look up; indexing [0] would raise.
        return
    word = fields[0]

    # With heteronym=True pypinyin returns, per character, a list of
    # candidate readings. As before, only the readings of the first
    # character are emitted.
    readings = pinyin(word, style=Style.TONE3, heteronym=True)
    if not readings:
        return

    # Iterate the readings directly instead of round-tripping through the
    # list's string representation and stripping brackets/quotes with a
    # regular expression (the previous pattern r'[]' does not compile).
    for syllable in readings[0]:
        print("{} {}".format(word, syllable))
|
||||
|
||||
|
||||
def main():
    """Process every line of the lexicon file given by ``--lexicon``.

    Raises:
      FileNotFoundError: If the ``--lexicon`` path is not an existing file.
    """
    args = get_args()
    lexicon = Path(args.lexicon)
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if not lexicon.is_file():
        raise FileNotFoundError(
            "Lexicon file not found: {}".format(lexicon)
        )

    # The lexicon contains Chinese text; do not depend on the locale's
    # default encoding when decoding it.
    with open(lexicon, encoding="utf-8") as f:
        for line in f:
            process_line(line=line)


if __name__ == "__main__":
    main()
|
@ -40,7 +40,8 @@ def get_parser():
|
||||
parser.add_argument(
|
||||
"--skip-ncols", "-s", default=0, type=int, help="skip first n columns"
|
||||
)
|
||||
parser.add_argument("--space", default="<space>", type=str, help="space symbol")
|
||||
parser.add_argument("--space", default="<space>", type=str,
|
||||
help="space symbol")
|
||||
parser.add_argument(
|
||||
"--non-lang-syms",
|
||||
"-l",
|
||||
@ -48,19 +49,15 @@ def get_parser():
|
||||
type=str,
|
||||
help="list of non-linguistic symobles, e.g., <NOISE> etc.",
|
||||
)
|
||||
parser.add_argument("text", type=str, default=False, nargs="?", help="input text")
|
||||
parser.add_argument("text", type=str, default=False, nargs="?",
|
||||
help="input text")
|
||||
parser.add_argument(
|
||||
"--trans_type",
|
||||
"-t",
|
||||
type=str,
|
||||
default="char",
|
||||
choices=["char", "phn"],
|
||||
help="""Transcript type. char/phn. e.g., for TIMIT FADG0_SI1279 -
|
||||
If trans_type is char,
|
||||
read from SI1279.WRD file -> "bricks are an alternative"
|
||||
Else if trans_type is phn,
|
||||
read from SI1279.PHN file -> "sil b r ih sil k s aa r er n aa l
|
||||
sil t er n ih sil t ih v sil" """,
|
||||
help="""Transcript type. char/phn""",
|
||||
)
|
||||
return parser
|
||||
|
||||
@ -78,7 +75,9 @@ def main():
|
||||
if args.text:
|
||||
f = codecs.open(args.text, encoding="utf-8")
|
||||
else:
|
||||
f = codecs.getreader("utf-8")(sys.stdin if is_python2 else sys.stdin.buffer)
|
||||
f = codecs.getreader("utf-8")(
|
||||
sys.stdin if is_python2 else sys.stdin.buffer
|
||||
)
|
||||
|
||||
sys.stdout = codecs.getwriter("utf-8")(
|
||||
sys.stdout if is_python2 else sys.stdout.buffer
|
||||
|
@ -40,6 +40,7 @@ from icefall.utils import (
|
||||
setup_logger,
|
||||
store_transcripts,
|
||||
write_error_stats,
|
||||
str2bool,
|
||||
)
|
||||
|
||||
|
||||
@ -108,6 +109,16 @@ def get_parser():
|
||||
default=3,
|
||||
help="Maximum number of symbols per frame",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--export",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="""When enabled, the averaged model is saved to
|
||||
conformer_ctc/exp/pretrained.pt. Note: only model.state_dict() is saved.
|
||||
pretrained.pt contains a dict {"model": model.state_dict()},
|
||||
which can be loaded by `icefall.checkpoint.load_checkpoint()`.
|
||||
""",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
@ -417,6 +428,13 @@ def main():
|
||||
model.to(device)
|
||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||
|
||||
if params.export:
|
||||
logging.info(f"Export averaged model to {params.exp_dir}/pretrained.pt")
|
||||
torch.save(
|
||||
{"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt"
|
||||
)
|
||||
return
|
||||
|
||||
model.to(device)
|
||||
model.eval()
|
||||
model.device = device
|
||||
@ -427,7 +445,9 @@ def main():
|
||||
wenetspeech = WenetSpeechDataModule(args)
|
||||
|
||||
test_net_dl = wenetspeech.test_dataloaders(wenetspeech.test_net_cuts())
|
||||
test_meetting_dl = wenetspeech.test_dataloaders(wenetspeech.test_meetting_cuts())
|
||||
test_meetting_dl = wenetspeech.test_dataloaders(
|
||||
wenetspeech.test_meetting_cuts()
|
||||
)
|
||||
|
||||
test_sets = ["TEST_NET", "TEST_MEETTING"]
|
||||
test_dls = [test_net_dl, test_meetting_dl]
|
||||
|
Loading…
x
Reference in New Issue
Block a user