mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-27 02:34:21 +00:00
fix style
This commit is contained in:
parent
503275e649
commit
28d1e8660e
@ -1,43 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
|
|
||||||
# Copyright 2021 Xiaomi Corporation (Author: Pingfeng Luo)
|
|
||||||
import argparse
|
|
||||||
import re
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Dict, List
|
|
||||||
from pypinyin import pinyin, lazy_pinyin, Style
|
|
||||||
|
|
||||||
def get_args():
    """Return the parsed command-line arguments.

    The only option is ``--lexicon``, the path of the input lexicon file.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--lexicon", type=str, help="The input lexicon file.")
    args = ap.parse_args()
    return args
|
|
||||||
|
|
||||||
|
|
||||||
def process_line(
    line: str
) -> None:
    """Print ``<word> <syllable>`` for each pinyin reading of a lexicon entry.

    Args:
      line:
        A line of the lexicon; the first whitespace-separated field is the
        word whose pinyin readings are printed.

    Returns:
      Return None.  Output is written to stdout, one line per reading.
    """
    fields = line.strip().split()
    if not fields:
        # Skip blank lines instead of raising IndexError.
        return
    word = fields[0]
    # heteronym=True returns every possible reading; Style.TONE3 appends the
    # tone number to each syllable, e.g. "ni3".  pinyin() returns a list with
    # one sub-list of readings per token; only the first token's readings are
    # emitted here, matching the original behavior.
    syllables = pinyin(word, style=Style.TONE3, heteronym=True)
    # The original stringified the list and stripped it with re.sub(r'[]', ...),
    # which raises re.error (empty character set); iterate the list directly.
    for s in syllables[0]:
        print("{} {}".format(word, s))
|
|
||||||
|
|
||||||
|
|
||||||
def main():
    """Entry point: print pinyin readings for every word in the lexicon.

    Raises:
      FileNotFoundError: If ``--lexicon`` does not name an existing file.
    """
    args = get_args()
    lexicon = Path(args.lexicon)
    # Validate explicitly: a bare assert would be stripped under ``python -O``.
    if not lexicon.is_file():
        raise FileNotFoundError("lexicon file not found: {}".format(lexicon))

    with open(lexicon) as f:
        for line in f:
            process_line(line=line)


if __name__ == "__main__":
    main()
|
|
@ -29,7 +29,7 @@ from lhotse.recipes.utils import read_manifests_if_cached
|
|||||||
|
|
||||||
def normalize_text(
|
def normalize_text(
|
||||||
utt: str,
|
utt: str,
|
||||||
#punct_pattern=re.compile(r"<(COMMA|PERIOD|QUESTIONMARK|EXCLAMATIONPOINT)>"),
|
# punct_pattern=re.compile(r"<(COMMA|PERIOD|QUESTIONMARK|EXCLAMATIONPOINT)>"),
|
||||||
punct_pattern=re.compile(r"<(PERIOD|QUESTIONMARK|EXCLAMATIONPOINT)>"),
|
punct_pattern=re.compile(r"<(PERIOD|QUESTIONMARK|EXCLAMATIONPOINT)>"),
|
||||||
whitespace_pattern=re.compile(r"\s\s+"),
|
whitespace_pattern=re.compile(r"\s\s+"),
|
||||||
) -> str:
|
) -> str:
|
||||||
|
@ -40,7 +40,8 @@ def get_parser():
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--skip-ncols", "-s", default=0, type=int, help="skip first n columns"
|
"--skip-ncols", "-s", default=0, type=int, help="skip first n columns"
|
||||||
)
|
)
|
||||||
parser.add_argument("--space", default="<space>", type=str, help="space symbol")
|
parser.add_argument("--space", default="<space>", type=str,
|
||||||
|
help="space symbol")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--non-lang-syms",
|
"--non-lang-syms",
|
||||||
"-l",
|
"-l",
|
||||||
@ -48,19 +49,15 @@ def get_parser():
|
|||||||
type=str,
|
type=str,
|
||||||
help="list of non-linguistic symobles, e.g., <NOISE> etc.",
|
help="list of non-linguistic symobles, e.g., <NOISE> etc.",
|
||||||
)
|
)
|
||||||
parser.add_argument("text", type=str, default=False, nargs="?", help="input text")
|
parser.add_argument("text", type=str, default=False, nargs="?",
|
||||||
|
help="input text")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--trans_type",
|
"--trans_type",
|
||||||
"-t",
|
"-t",
|
||||||
type=str,
|
type=str,
|
||||||
default="char",
|
default="char",
|
||||||
choices=["char", "phn"],
|
choices=["char", "phn"],
|
||||||
help="""Transcript type. char/phn. e.g., for TIMIT FADG0_SI1279 -
|
help="""Transcript type. char/phn""",
|
||||||
If trans_type is char,
|
|
||||||
read from SI1279.WRD file -> "bricks are an alternative"
|
|
||||||
Else if trans_type is phn,
|
|
||||||
read from SI1279.PHN file -> "sil b r ih sil k s aa r er n aa l
|
|
||||||
sil t er n ih sil t ih v sil" """,
|
|
||||||
)
|
)
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
@ -78,7 +75,9 @@ def main():
|
|||||||
if args.text:
|
if args.text:
|
||||||
f = codecs.open(args.text, encoding="utf-8")
|
f = codecs.open(args.text, encoding="utf-8")
|
||||||
else:
|
else:
|
||||||
f = codecs.getreader("utf-8")(sys.stdin if is_python2 else sys.stdin.buffer)
|
f = codecs.getreader("utf-8")(
|
||||||
|
sys.stdin if is_python2 else sys.stdin.buffer
|
||||||
|
)
|
||||||
|
|
||||||
sys.stdout = codecs.getwriter("utf-8")(
|
sys.stdout = codecs.getwriter("utf-8")(
|
||||||
sys.stdout if is_python2 else sys.stdout.buffer
|
sys.stdout if is_python2 else sys.stdout.buffer
|
||||||
@ -88,7 +87,7 @@ def main():
|
|||||||
while line:
|
while line:
|
||||||
x = line.split()
|
x = line.split()
|
||||||
print(" ".join(x[: args.skip_ncols]), end=" ")
|
print(" ".join(x[: args.skip_ncols]), end=" ")
|
||||||
a = " ".join(x[args.skip_ncols :])
|
a = " ".join(x[args.skip_ncols:])
|
||||||
|
|
||||||
# get all matched positions
|
# get all matched positions
|
||||||
match_pos = []
|
match_pos = []
|
||||||
@ -118,7 +117,7 @@ def main():
|
|||||||
i += 1
|
i += 1
|
||||||
a = chars
|
a = chars
|
||||||
|
|
||||||
a = [a[j : j + n] for j in range(0, len(a), n)]
|
a = [a[j:j + n] for j in range(0, len(a), n)]
|
||||||
|
|
||||||
a_flat = []
|
a_flat = []
|
||||||
for z in a:
|
for z in a:
|
||||||
|
@ -40,6 +40,7 @@ from icefall.utils import (
|
|||||||
setup_logger,
|
setup_logger,
|
||||||
store_transcripts,
|
store_transcripts,
|
||||||
write_error_stats,
|
write_error_stats,
|
||||||
|
str2bool,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -108,6 +109,16 @@ def get_parser():
|
|||||||
default=3,
|
default=3,
|
||||||
help="Maximum number of symbols per frame",
|
help="Maximum number of symbols per frame",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--export",
|
||||||
|
type=str2bool,
|
||||||
|
default=False,
|
||||||
|
help="""When enabled, the averaged model is saved to
|
||||||
|
conformer_ctc/exp/pretrained.pt. Note: only model.state_dict() is saved.
|
||||||
|
pretrained.pt contains a dict {"model": model.state_dict()},
|
||||||
|
which can be loaded by `icefall.checkpoint.load_checkpoint()`.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
@ -417,6 +428,13 @@ def main():
|
|||||||
model.to(device)
|
model.to(device)
|
||||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||||
|
|
||||||
|
if params.export:
|
||||||
|
logging.info(f"Export averaged model to {params.exp_dir}/pretrained.pt")
|
||||||
|
torch.save(
|
||||||
|
{"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt"
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
model.to(device)
|
model.to(device)
|
||||||
model.eval()
|
model.eval()
|
||||||
model.device = device
|
model.device = device
|
||||||
@ -427,7 +445,9 @@ def main():
|
|||||||
wenetspeech = WenetSpeechDataModule(args)
|
wenetspeech = WenetSpeechDataModule(args)
|
||||||
|
|
||||||
test_net_dl = wenetspeech.test_dataloaders(wenetspeech.test_net_cuts())
|
test_net_dl = wenetspeech.test_dataloaders(wenetspeech.test_net_cuts())
|
||||||
test_meetting_dl = wenetspeech.test_dataloaders(wenetspeech.test_meetting_cuts())
|
test_meetting_dl = wenetspeech.test_dataloaders(
|
||||||
|
wenetspeech.test_meetting_cuts()
|
||||||
|
)
|
||||||
|
|
||||||
test_sets = ["TEST_NET", "TEST_MEETTING"]
|
test_sets = ["TEST_NET", "TEST_MEETTING"]
|
||||||
test_dls = [test_net_dl, test_meetting_dl]
|
test_dls = [test_net_dl, test_meetting_dl]
|
||||||
|
Loading…
x
Reference in New Issue
Block a user