From 51d9c4f028b3eb5e8e909670ff94b62c2e5f62bf Mon Sep 17 00:00:00 2001 From: marcoyang Date: Wed, 11 Oct 2023 10:54:40 +0800 Subject: [PATCH] refactor code --- .../ASR/zipformer_prompt_asr/decode_bert.py | 24 +------------------ .../text_normalization.py | 24 +++++++++++++++++++ 2 files changed, 25 insertions(+), 23 deletions(-) diff --git a/egs/libriheavy/ASR/zipformer_prompt_asr/decode_bert.py b/egs/libriheavy/ASR/zipformer_prompt_asr/decode_bert.py index cc6d4f654..e71999b0a 100755 --- a/egs/libriheavy/ASR/zipformer_prompt_asr/decode_bert.py +++ b/egs/libriheavy/ASR/zipformer_prompt_asr/decode_bert.py @@ -114,6 +114,7 @@ from beam_search import greedy_search, greedy_search_batch, modified_beam_search from dataset import naive_triplet_text_sampling, random_shuffle_subset from ls_text_normalization import word_normalization from text_normalization import ( + _apply_style_transform, lower_all_char, lower_only_alpha, ref_text_normalization, @@ -387,29 +388,6 @@ def get_parser(): return parser -def _apply_style_transform(text: List[str], transform: str) -> List[str]: - """Apply transform to a list of text. By default, the text are in - ground truth format, i.e mixed-punc. 
- - Args: - text (List[str]): Input text string - transform (str): Transform to be applied - - Returns: - List[str]: _description_ - """ - if transform == "mixed-punc": - return text - elif transform == "upper-no-punc": - return [upper_only_alpha(s) for s in text] - elif transform == "lower-no-punc": - return [lower_only_alpha(s) for s in text] - elif transform == "lower-punc": - return [lower_all_char(s) for s in text] - else: - raise NotImplementedError(f"Unseen transform: {transform}") - - def decode_one_batch( params: AttributeDict, model: nn.Module, diff --git a/egs/libriheavy/ASR/zipformer_prompt_asr/text_normalization.py b/egs/libriheavy/ASR/zipformer_prompt_asr/text_normalization.py index 657089f46..efb4acc3c 100644 --- a/egs/libriheavy/ASR/zipformer_prompt_asr/text_normalization.py +++ b/egs/libriheavy/ASR/zipformer_prompt_asr/text_normalization.py @@ -15,6 +15,7 @@ # limitations under the License. import re +from typing import List def train_text_normalization(s: str) -> str: @@ -70,6 +71,29 @@ def upper_all_char(text: str) -> str: return text.upper() +def _apply_style_transform(text: List[str], transform: str) -> List[str]: + """Apply transform to a list of text. By default, the text is in + ground truth format, i.e., mixed-punc. + + Args: + text (List[str]): Input text strings + transform (str): Transform to be applied + + Returns: + List[str]: The transformed strings ("mixed-punc" returns the input list itself) + """ + if transform == "mixed-punc": + return text + elif transform == "upper-no-punc": + return [upper_only_alpha(s) for s in text] + elif transform == "lower-no-punc": + return [lower_only_alpha(s) for s in text] + elif transform == "lower-punc": + return [lower_all_char(s) for s in text] + else: + raise NotImplementedError(f"Unseen transform: {transform}") + + if __name__ == "__main__": ref_text = "Mixed-case English transcription, with punctuation. Actually, it is fully not related." print(ref_text)