From 19f88482be322068c2a5fc8285b0fc780471e157 Mon Sep 17 00:00:00 2001
From: pkufool
Date: Fri, 21 Jun 2024 18:05:52 +0800
Subject: [PATCH] merge with master

---
 egs/wenetspeech/ASR/zipformer/decode.py | 12 +---
 egs/wenetspeech/ASR/zipformer/train.py  | 73 +++++++++++++------------
 requirements-ci.txt                     |  1 +
 requirements.txt                        |  1 +
 test/test_utils.py                      | 29 ++++++++++
 5 files changed, 73 insertions(+), 43 deletions(-)

diff --git a/egs/wenetspeech/ASR/zipformer/decode.py b/egs/wenetspeech/ASR/zipformer/decode.py
index 0fbc8244b..56a4014d9 100755
--- a/egs/wenetspeech/ASR/zipformer/decode.py
+++ b/egs/wenetspeech/ASR/zipformer/decode.py
@@ -556,18 +556,14 @@ def save_results(
 ):
     test_set_wers = dict()
     for key, results in results_dict.items():
-        recog_path = (
-            params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
-        )
+        recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
         results = sorted(results)
         store_transcripts(filename=recog_path, texts=results)
         logging.info(f"The transcripts are stored in {recog_path}")
 
         # The following prints out WERs, per-word error statistics and aligned
         # ref/hyp pairs.
-        errs_filename = (
-            params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
-        )
+        errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
         with open(errs_filename, "w") as f:
             wer = write_error_stats(
                 f, f"{test_set_name}-{key}", results, enable_log=True
@@ -577,9 +573,7 @@
     logging.info("Wrote detailed error stats to {}".format(errs_filename))
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
-    errs_info = (
-        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
-    )
+    errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
         for key, val in test_set_wers:
diff --git a/egs/wenetspeech/ASR/zipformer/train.py b/egs/wenetspeech/ASR/zipformer/train.py
index 3d3762916..594192211 100755
--- a/egs/wenetspeech/ASR/zipformer/train.py
+++ b/egs/wenetspeech/ASR/zipformer/train.py
@@ -239,12 +239,22 @@ def add_model_arguments(parser: argparse.ArgumentParser):
         chunk left-context frames will be chosen randomly from this list; else not relevant.""",
     )
 
-
-def get_parser():
-    parser = argparse.ArgumentParser(
-        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    parser.add_argument(
+        "--use-transducer",
+        type=str2bool,
+        default=True,
+        help="If True, use Transducer head.",
     )
 
+    parser.add_argument(
+        "--use-ctc",
+        type=str2bool,
+        default=False,
+        help="If True, use CTC head.",
+    )
+
+
+def add_training_arguments(parser: argparse.ArgumentParser):
     parser.add_argument(
         "--world-size",
         type=int,
@@ -302,16 +312,6 @@ def get_parser():
         """,
     )
 
-    parser.add_argument(
-        "--lang-dir",
-        type=str,
-        default="data/lang_char",
-        help="""The lang dir
-        It contains language related input files such as
-        "lexicon.txt"
-        """,
-    )
-
     parser.add_argument(
         "--base-lr", type=float, default=0.045, help="The base learning rate."
     )
@@ -379,6 +379,13 @@ def get_parser():
         with this parameter before adding to the final loss.""",
     )
 
+    parser.add_argument(
+        "--ctc-loss-scale",
+        type=float,
+        default=0.2,
+        help="Scale for CTC loss.",
+    )
+
     parser.add_argument(
         "--seed",
         type=int,
@@ -444,6 +451,24 @@
         help="Whether to use half precision training.",
     )
 
+
+def get_parser():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument(
+        "--lang-dir",
+        type=str,
+        default="data/lang_char",
+        help="""The lang dir
+        It contains language related input files such as
+        "lexicon.txt"
+        """,
+    )
+
+    add_training_arguments(parser)
+    add_model_arguments(parser)
 
     return parser
 
@@ -1154,26 +1179,6 @@
             # )
             return False
 
-        # In pruned RNN-T, we require that T >= S
-        # where T is the number of feature frames after subsampling
-        # and S is the number of tokens in the utterance
-
-        # In ./zipformer.py, the conv module uses the following expression
-        # for subsampling
-        T = ((c.num_frames - 7) // 2 + 1) // 2
-        tokens = graph_compiler.texts_to_ids([c.supervisions[0].text])[0]
-
-        if T < len(tokens):
-            logging.warning(
-                f"Exclude cut with ID {c.id} from training. "
-                f"Number of frames (before subsampling): {c.num_frames}. "
-                f"Number of frames (after subsampling): {T}. "
-                f"Text: {c.supervisions[0].text}. "
-                f"Tokens: {tokens}. "
-                f"Number of tokens: {len(tokens)}"
-            )
-            return False
-
         return True
 
     train_cuts = train_cuts.filter(remove_short_and_long_utt)
diff --git a/requirements-ci.txt b/requirements-ci.txt
index ebea04615..713f7ec8a 100644
--- a/requirements-ci.txt
+++ b/requirements-ci.txt
@@ -28,5 +28,6 @@ multi_quantization
 onnx
 onnxmltools
 onnxruntime
+pypinyin
 kaldifst
 kaldi-decoder
diff --git a/requirements.txt b/requirements.txt
index 226adaba1..79e00ffbb 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -3,6 +3,7 @@ kaldilm
 kaldialign
 num2words
 kaldi-decoder
+pypinyin
 sentencepiece>=0.1.96
 pypinyin==0.50.0
 tensorboard
diff --git a/test/test_utils.py b/test/test_utils.py
index 31f06bd51..8e1abe279 100644
--- a/test/test_utils.py
+++ b/test/test_utils.py
@@ -28,6 +28,7 @@ from icefall.utils import (
     encode_supervisions,
     get_texts,
     make_pad_mask,
+    text_to_pinyin,
 )
 
 
@@ -163,3 +164,31 @@ def test_add_eos():
         [[1, 2, eos_id], [3, eos_id], [eos_id], [5, 8, 9, eos_id]]
     )
     assert str(ragged_eos) == str(expected)
+
+
+def test_text_to_pinyin():
+    txt = "想吃KFC"
+
+    r = text_to_pinyin(txt, mode="full_with_tone")
+    assert " ".join(r) == "xiǎng chī KFC"
+
+    r = text_to_pinyin(txt, mode="full_with_tone", errors="split")
+    assert " ".join(r) == "xiǎng chī K F C"
+
+    r = text_to_pinyin(txt, mode="full_no_tone", errors="default")
+    assert " ".join(r) == "xiang chi KFC"
+
+    r = text_to_pinyin(txt, mode="full_no_tone", errors="split")
+    assert " ".join(r) == "xiang chi K F C"
+
+    r = text_to_pinyin(txt, mode="partial_with_tone")
+    assert " ".join(r) == "x iǎng ch ī KFC"
+
+    r = text_to_pinyin(txt, mode="partial_with_tone", errors="split")
+    assert " ".join(r) == "x iǎng ch ī K F C"
+
+    r = text_to_pinyin(txt, mode="partial_no_tone", errors="default")
+    assert " ".join(r) == "x iang ch i KFC"
+
+    r = text_to_pinyin(txt, mode="partial_no_tone", errors="split")
+    assert " ".join(r) == "x iang ch i K F C"
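
Note on the new train.py flags: --use-transducer (default True) and
--use-ctc (default False) select which decoding heads are trained, and
--ctc-loss-scale (default 0.2) weights the CTC term. As a rough sketch of
how these flags typically combine in the total loss: the names follow the
zipformer recipe (simple_loss, pruned_loss, ctc_loss, --simple-loss-scale),
but combine_losses is a hypothetical helper for illustration, not the exact
code in this patch's train.py:

import torch

def combine_losses(
    params,
    simple_loss: torch.Tensor,
    pruned_loss: torch.Tensor,
    ctc_loss: torch.Tensor,
) -> torch.Tensor:
    # Each enabled head contributes a weighted term to the total loss.
    loss = torch.zeros_like(pruned_loss)
    if params.use_transducer:
        # Pruned RNN-T objective: a scaled "simple" term plus the pruned term.
        loss = loss + params.simple_loss_scale * simple_loss + pruned_loss
    if params.use_ctc:
        # Auxiliary CTC head, weighted by --ctc-loss-scale.
        loss = loss + params.ctc_loss_scale * ctc_loss
    return loss

Because --use-transducer defaults to True and --use-ctc to False, existing
training commands keep the pure transducer objective unless CTC is enabled
explicitly.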
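Note on text_to_pinyin, exercised by the new test: mode selects whole
syllables ("full_*") or initial/final pairs ("partial_*"), with or without
tone marks, while errors controls whether non-Chinese substrings such as
"KFC" are kept whole ("default") or split into single characters ("split").
Below is a minimal sketch of such a helper built on pypinyin. It reproduces
the behaviour the assertions expect, but the names text_to_pinyin_sketch and
handle_oov are hypothetical, and this is not necessarily the implementation
that ships in icefall/utils.py:

from typing import List

from pypinyin import Style, lazy_pinyin

def text_to_pinyin_sketch(
    txt: str, mode: str = "full_with_tone", errors: str = "default"
) -> List[str]:
    assert errors in ("default", "split"), errors

    # pypinyin hands non-Chinese substrings to the `errors` callable;
    # keep them whole ("KFC") or split them into characters ("K F C").
    def handle_oov(chars: str) -> List[str]:
        return list(chars) if errors == "split" else [chars]

    if mode.startswith("full"):
        style = Style.TONE if mode == "full_with_tone" else Style.NORMAL
        return lazy_pinyin(txt, style=style, errors=handle_oov)

    # "partial" modes: emit the initial and the final as separate tokens,
    # e.g. 想 -> "x" + "iǎng".
    initials = lazy_pinyin(txt, style=Style.INITIALS, strict=False, errors=handle_oov)
    final_style = Style.FINALS_TONE if mode == "partial_with_tone" else Style.FINALS
    finals = lazy_pinyin(txt, style=final_style, strict=False, errors=handle_oov)

    tokens: List[str] = []
    for ini, fin in zip(initials, finals):
        if ini == fin:
            # Non-Chinese tokens come back unchanged from both passes.
            tokens.append(ini)
        else:
            # Drop empty initials (syllables such as "an" have none).
            tokens.extend(t for t in (ini, fin) if t)
    return tokens

With this sketch, text_to_pinyin_sketch("想吃KFC", mode="partial_no_tone",
errors="split") yields ["x", "iang", "ch", "i", "K", "F", "C"], matching the
last assertion in the new test.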