From 6813b754e93ace0509bbee4cb3b81af02d879bbb Mon Sep 17 00:00:00 2001 From: PingFeng Luo Date: Fri, 31 Dec 2021 19:11:24 +0800 Subject: [PATCH] fix style --- .../ASR/local/compute_fbank_wenetspeech.py | 7 ++++++- .../ASR/local/preprocess_wenetspeech.py | 7 ++++++- egs/wenetspeech/ASR/local/text2token.py | 18 ++++++++++-------- .../ASR/tdnn_lstm_ctc/asr_datamodule.py | 4 ++-- .../ASR/transducer_stateless/decode.py | 4 ++-- 5 files changed, 26 insertions(+), 14 deletions(-) diff --git a/egs/wenetspeech/ASR/local/compute_fbank_wenetspeech.py b/egs/wenetspeech/ASR/local/compute_fbank_wenetspeech.py index 3fb32657f..9d1bfd6f7 100755 --- a/egs/wenetspeech/ASR/local/compute_fbank_wenetspeech.py +++ b/egs/wenetspeech/ASR/local/compute_fbank_wenetspeech.py @@ -128,7 +128,12 @@ def compute_fbank_wenetspeech(args): output_dir = Path("data/fbank") dataset_parts = ( - "L", "M", "S", "DEV", "TEST_NET", "TEST_MEETING", + "L", + "M", + "S", + "DEV", + "TEST_NET", + "TEST_MEETING", ) manifests = read_manifests_if_cached( dataset_parts=dataset_parts, diff --git a/egs/wenetspeech/ASR/local/preprocess_wenetspeech.py b/egs/wenetspeech/ASR/local/preprocess_wenetspeech.py index a5a5d5b23..3b977c868 100755 --- a/egs/wenetspeech/ASR/local/preprocess_wenetspeech.py +++ b/egs/wenetspeech/ASR/local/preprocess_wenetspeech.py @@ -49,7 +49,12 @@ def preprocess_wenet_speech(): output_dir.mkdir(exist_ok=True) dataset_parts = ( - "L", "M", "S", "DEV", "TEST_NET", "TEST_MEETING", + "L", + "M", + "S", + "DEV", + "TEST_NET", + "TEST_MEETING", ) logging.info("Loading manifest (may take 10 minutes)") diff --git a/egs/wenetspeech/ASR/local/text2token.py b/egs/wenetspeech/ASR/local/text2token.py index 9140da6e8..8f8c433bf 100755 --- a/egs/wenetspeech/ASR/local/text2token.py +++ b/egs/wenetspeech/ASR/local/text2token.py @@ -40,8 +40,9 @@ def get_parser(): parser.add_argument( "--skip-ncols", "-s", default=0, type=int, help="skip first n columns" ) - parser.add_argument("--space", default="", type=str, - help="space symbol") + parser.add_argument( + "--space", default="", type=str, help="space symbol" + ) parser.add_argument( "--non-lang-syms", "-l", @@ -49,8 +50,9 @@ def get_parser(): type=str, help="list of non-linguistic symobles, e.g., etc.", ) - parser.add_argument("text", type=str, default=False, nargs="?", - help="input text") + parser.add_argument( + "text", type=str, default=False, nargs="?", help="input text" + ) parser.add_argument( "--trans_type", "-t", @@ -76,8 +78,8 @@ def main(): f = codecs.open(args.text, encoding="utf-8") else: f = codecs.getreader("utf-8")( - sys.stdin if is_python2 else sys.stdin.buffer - ) + sys.stdin if is_python2 else sys.stdin.buffer + ) sys.stdout = codecs.getwriter("utf-8")( sys.stdout if is_python2 else sys.stdout.buffer @@ -87,7 +89,7 @@ def main(): while line: x = line.split() print(" ".join(x[: args.skip_ncols]), end=" ") - a = " ".join(x[args.skip_ncols:]) + a = " ".join(x[args.skip_ncols :]) # get all matched positions match_pos = [] @@ -117,7 +119,7 @@ def main(): i += 1 a = chars - a = [a[j:j + n] for j in range(0, len(a), n)] + a = [a[j : j + n] for j in range(0, len(a), n)] a_flat = [] for z in a: diff --git a/egs/wenetspeech/ASR/tdnn_lstm_ctc/asr_datamodule.py b/egs/wenetspeech/ASR/tdnn_lstm_ctc/asr_datamodule.py index cd08d90b1..4de55b32e 100644 --- a/egs/wenetspeech/ASR/tdnn_lstm_ctc/asr_datamodule.py +++ b/egs/wenetspeech/ASR/tdnn_lstm_ctc/asr_datamodule.py @@ -370,5 +370,5 @@ class WenetSpeechDataModule: def test_meetting_cuts(self) -> List[CutSet]: logging.info("About to get TEST_MEETTING cuts") return load_manifest( - self.args.manifest_dir / "cuts_TEST_MEETTING.jsonl.gz" - ) + self.args.manifest_dir / "cuts_TEST_MEETTING.jsonl.gz" + ) diff --git a/egs/wenetspeech/ASR/transducer_stateless/decode.py b/egs/wenetspeech/ASR/transducer_stateless/decode.py index 5976a8f0d..574591122 100755 --- a/egs/wenetspeech/ASR/transducer_stateless/decode.py +++ b/egs/wenetspeech/ASR/transducer_stateless/decode.py @@ -446,8 +446,8 @@ def main(): test_net_dl = wenetspeech.test_dataloaders(wenetspeech.test_net_cuts()) test_meetting_dl = wenetspeech.test_dataloaders( - wenetspeech.test_meetting_cuts() - ) + wenetspeech.test_meetting_cuts() + ) test_sets = ["TEST_NET", "TEST_MEETTING"] test_dls = [test_net_dl, test_meetting_dl]