fix style

This commit is contained in:
PingFeng Luo 2021-12-31 19:11:24 +08:00
parent 46eab8b727
commit 6813b754e9
5 changed files with 26 additions and 14 deletions

View File

@@ -128,7 +128,12 @@ def compute_fbank_wenetspeech(args):
output_dir = Path("data/fbank") output_dir = Path("data/fbank")
dataset_parts = ( dataset_parts = (
"L", "M", "S", "DEV", "TEST_NET", "TEST_MEETING", "L",
"M",
"S",
"DEV",
"TEST_NET",
"TEST_MEETING",
) )
manifests = read_manifests_if_cached( manifests = read_manifests_if_cached(
dataset_parts=dataset_parts, dataset_parts=dataset_parts,

View File

@@ -49,7 +49,12 @@ def preprocess_wenet_speech():
output_dir.mkdir(exist_ok=True) output_dir.mkdir(exist_ok=True)
dataset_parts = ( dataset_parts = (
"L", "M", "S", "DEV", "TEST_NET", "TEST_MEETING", "L",
"M",
"S",
"DEV",
"TEST_NET",
"TEST_MEETING",
) )
logging.info("Loading manifest (may take 10 minutes)") logging.info("Loading manifest (may take 10 minutes)")

View File

@@ -40,8 +40,9 @@ def get_parser():
parser.add_argument( parser.add_argument(
"--skip-ncols", "-s", default=0, type=int, help="skip first n columns" "--skip-ncols", "-s", default=0, type=int, help="skip first n columns"
) )
parser.add_argument("--space", default="<space>", type=str, parser.add_argument(
help="space symbol") "--space", default="<space>", type=str, help="space symbol"
)
parser.add_argument( parser.add_argument(
"--non-lang-syms", "--non-lang-syms",
"-l", "-l",
@@ -49,8 +50,9 @@ def get_parser():
type=str, type=str,
help="list of non-linguistic symobles, e.g., <NOISE> etc.", help="list of non-linguistic symobles, e.g., <NOISE> etc.",
) )
parser.add_argument("text", type=str, default=False, nargs="?", parser.add_argument(
help="input text") "text", type=str, default=False, nargs="?", help="input text"
)
parser.add_argument( parser.add_argument(
"--trans_type", "--trans_type",
"-t", "-t",
@@ -76,8 +78,8 @@ def main():
f = codecs.open(args.text, encoding="utf-8") f = codecs.open(args.text, encoding="utf-8")
else: else:
f = codecs.getreader("utf-8")( f = codecs.getreader("utf-8")(
sys.stdin if is_python2 else sys.stdin.buffer sys.stdin if is_python2 else sys.stdin.buffer
) )
sys.stdout = codecs.getwriter("utf-8")( sys.stdout = codecs.getwriter("utf-8")(
sys.stdout if is_python2 else sys.stdout.buffer sys.stdout if is_python2 else sys.stdout.buffer
@@ -87,7 +89,7 @@ def main():
while line: while line:
x = line.split() x = line.split()
print(" ".join(x[: args.skip_ncols]), end=" ") print(" ".join(x[: args.skip_ncols]), end=" ")
a = " ".join(x[args.skip_ncols:]) a = " ".join(x[args.skip_ncols :])
# get all matched positions # get all matched positions
match_pos = [] match_pos = []
@@ -117,7 +119,7 @@ def main():
i += 1 i += 1
a = chars a = chars
a = [a[j:j + n] for j in range(0, len(a), n)] a = [a[j : j + n] for j in range(0, len(a), n)]
a_flat = [] a_flat = []
for z in a: for z in a:

View File

@@ -370,5 +370,5 @@ class WenetSpeechDataModule:
def test_meetting_cuts(self) -> List[CutSet]: def test_meetting_cuts(self) -> List[CutSet]:
logging.info("About to get TEST_MEETTING cuts") logging.info("About to get TEST_MEETTING cuts")
return load_manifest( return load_manifest(
self.args.manifest_dir / "cuts_TEST_MEETTING.jsonl.gz" self.args.manifest_dir / "cuts_TEST_MEETTING.jsonl.gz"
) )

View File

@@ -446,8 +446,8 @@ def main():
test_net_dl = wenetspeech.test_dataloaders(wenetspeech.test_net_cuts()) test_net_dl = wenetspeech.test_dataloaders(wenetspeech.test_net_cuts())
test_meetting_dl = wenetspeech.test_dataloaders( test_meetting_dl = wenetspeech.test_dataloaders(
wenetspeech.test_meetting_cuts() wenetspeech.test_meetting_cuts()
) )
test_sets = ["TEST_NET", "TEST_MEETTING"] test_sets = ["TEST_NET", "TEST_MEETTING"]
test_dls = [test_net_dl, test_meetting_dl] test_dls = [test_net_dl, test_meetting_dl]