From 28d1e8660ebaf752fa7caec46b0201456ed39fac Mon Sep 17 00:00:00 2001
From: PingFeng Luo
Date: Fri, 31 Dec 2021 18:50:12 +0800
Subject: [PATCH] fix style

---
 .../ASR/local/make_syball_lexicon.py    | 43 -------------------
 .../ASR/local/preprocess_wenetspeech.py |  2 +-
 egs/wenetspeech/ASR/local/text2token.py | 21 +++++----
 .../ASR/transducer_stateless/decode.py  | 22 +++++++++-
 4 files changed, 32 insertions(+), 56 deletions(-)
 delete mode 100755 egs/wenetspeech/ASR/local/make_syball_lexicon.py

diff --git a/egs/wenetspeech/ASR/local/make_syball_lexicon.py b/egs/wenetspeech/ASR/local/make_syball_lexicon.py
deleted file mode 100755
index 9ab099842..000000000
--- a/egs/wenetspeech/ASR/local/make_syball_lexicon.py
+++ /dev/null
@@ -1,43 +0,0 @@
-#!/usr/bin/env python3
-
-# Copyright 2021 Xiaomi Corporation (Author: Pingfeng Luo)
-import argparse
-import re
-from pathlib import Path
-from typing import Dict, List
-from pypinyin import pinyin, lazy_pinyin, Style
-
-def get_args():
-    parser = argparse.ArgumentParser()
-    parser.add_argument("--lexicon", type=str, help="The input lexicon file.")
-    return parser.parse_args()
-
-
-def process_line(
-    line: str
-) -> None:
-    """
-    Args:
-      line:
-        A line of transcript consisting of space(s) separated words.
-    Returns:
-      Return None.
-    """
-    char = line.strip().split()[0]
-    syllables = pinyin(char, style=Style.TONE3, heteronym=True)
-    syllables = ''.join(str(syllables[0][:]))
-    for s in syllables.split(',') :
-        print("{} {}".format(char, re.sub(r'[]', '', s)))
-
-
-def main():
-    args = get_args()
-    assert Path(args.lexicon).is_file()
-
-    with open(args.lexicon) as f:
-        for line in f:
-            process_line(line=line)
-
-
-if __name__ == "__main__":
-    main()
diff --git a/egs/wenetspeech/ASR/local/preprocess_wenetspeech.py b/egs/wenetspeech/ASR/local/preprocess_wenetspeech.py
index 8bd3a79bc..a5a5d5b23 100755
--- a/egs/wenetspeech/ASR/local/preprocess_wenetspeech.py
+++ b/egs/wenetspeech/ASR/local/preprocess_wenetspeech.py
@@ -29,7 +29,7 @@ from lhotse.recipes.utils import read_manifests_if_cached
 
 def normalize_text(
     utt: str,
-    #punct_pattern=re.compile(r"<(COMMA|PERIOD|QUESTIONMARK|EXCLAMATIONPOINT)>"),
+    # punct_pattern=re.compile(r"<(COMMA|PERIOD|QUESTIONMARK|EXCLAMATIONPOINT)>"),
     punct_pattern=re.compile(r"<(PERIOD|QUESTIONMARK|EXCLAMATIONPOINT)>"),
     whitespace_pattern=re.compile(r"\s\s+"),
 ) -> str:
diff --git a/egs/wenetspeech/ASR/local/text2token.py b/egs/wenetspeech/ASR/local/text2token.py
index 56c39138f..9140da6e8 100755
--- a/egs/wenetspeech/ASR/local/text2token.py
+++ b/egs/wenetspeech/ASR/local/text2token.py
@@ -40,7 +40,8 @@ def get_parser():
     parser.add_argument(
         "--skip-ncols", "-s", default=0, type=int, help="skip first n columns"
     )
-    parser.add_argument("--space", default="<space>", type=str, help="space symbol")
+    parser.add_argument("--space", default="<space>", type=str,
+                        help="space symbol")
     parser.add_argument(
         "--non-lang-syms",
         "-l",
@@ -48,19 +49,15 @@ def get_parser():
         type=str,
         help="list of non-linguistic symobles, e.g., <NOISE> etc.",
     )
-    parser.add_argument("text", type=str, default=False, nargs="?", help="input text")
+    parser.add_argument("text", type=str, default=False, nargs="?",
+                        help="input text")
     parser.add_argument(
         "--trans_type",
         "-t",
         type=str,
         default="char",
         choices=["char", "phn"],
-        help="""Transcript type. char/phn. e.g., for TIMIT FADG0_SI1279 -
-        If trans_type is char,
-        read from SI1279.WRD file -> "bricks are an alternative"
-        Else if trans_type is phn,
-        read from SI1279.PHN file -> "sil b r ih sil k s aa r er n aa l
-        sil t er n ih sil t ih v sil" """,
+        help="""Transcript type. char/phn""",
     )
 
     return parser
@@ -78,7 +75,9 @@ def main():
     if args.text:
         f = codecs.open(args.text, encoding="utf-8")
     else:
-        f = codecs.getreader("utf-8")(sys.stdin if is_python2 else sys.stdin.buffer)
+        f = codecs.getreader("utf-8")(
+            sys.stdin if is_python2 else sys.stdin.buffer
+        )
 
     sys.stdout = codecs.getwriter("utf-8")(
         sys.stdout if is_python2 else sys.stdout.buffer
@@ -88,7 +87,7 @@ def main():
     while line:
         x = line.split()
         print(" ".join(x[: args.skip_ncols]), end=" ")
-        a = " ".join(x[args.skip_ncols :])
+        a = " ".join(x[args.skip_ncols:])
 
         # get all matched positions
         match_pos = []
@@ -118,7 +117,7 @@ def main():
                 i += 1
             a = chars
 
-        a = [a[j : j + n] for j in range(0, len(a), n)]
+        a = [a[j:j + n] for j in range(0, len(a), n)]
 
         a_flat = []
         for z in a:
diff --git a/egs/wenetspeech/ASR/transducer_stateless/decode.py b/egs/wenetspeech/ASR/transducer_stateless/decode.py
index ad0a73389..5976a8f0d 100755
--- a/egs/wenetspeech/ASR/transducer_stateless/decode.py
+++ b/egs/wenetspeech/ASR/transducer_stateless/decode.py
@@ -40,6 +40,7 @@ from icefall.utils import (
     setup_logger,
     store_transcripts,
     write_error_stats,
+    str2bool,
 )
 
 
@@ -108,6 +109,16 @@ def get_parser():
         default=3,
         help="Maximum number of symbols per frame",
     )
+    parser.add_argument(
+        "--export",
+        type=str2bool,
+        default=False,
+        help="""When enabled, the averaged model is saved to
+        transducer_stateless/exp/pretrained.pt. Note: only model.state_dict() is saved.
+        pretrained.pt contains a dict {"model": model.state_dict()},
+        which can be loaded by `icefall.checkpoint.load_checkpoint()`.
+        """,
+    )
 
     return parser
 
@@ -417,6 +428,13 @@ def main():
         model.to(device)
         model.load_state_dict(average_checkpoints(filenames, device=device))
 
+    if params.export:
+        logging.info(f"Export averaged model to {params.exp_dir}/pretrained.pt")
+        torch.save(
+            {"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt"
+        )
+        return
+
     model.to(device)
     model.eval()
     model.device = device
@@ -427,7 +445,9 @@ def main():
 
     wenetspeech = WenetSpeechDataModule(args)
     test_net_dl = wenetspeech.test_dataloaders(wenetspeech.test_net_cuts())
-    test_meetting_dl = wenetspeech.test_dataloaders(wenetspeech.test_meetting_cuts())
+    test_meetting_dl = wenetspeech.test_dataloaders(
+        wenetspeech.test_meetting_cuts()
+    )
 
     test_sets = ["TEST_NET", "TEST_MEETTING"]
     test_dls = [test_net_dl, test_meetting_dl]
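
Note on the new --export flag: a minimal sketch of loading the exported
checkpoint back, assuming `model` has already been constructed the same way
decode.py builds it before decoding, and assuming the recipe's default
exp dir of transducer_stateless/exp (both are assumptions, not part of this
patch):

    import torch

    # pretrained.pt holds {"model": model.state_dict()}, exactly as written
    # by the --export branch added above.
    checkpoint = torch.load(
        "transducer_stateless/exp/pretrained.pt", map_location="cpu"
    )
    model.load_state_dict(checkpoint["model"])
    model.eval()

As the flag's help text notes, the same file can also be consumed through
icefall.checkpoint.load_checkpoint(), which expects this {"model": ...}
layout.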