refactor code

This commit is contained in:
marcoyang 2023-10-11 10:54:40 +08:00
parent 77fc1c0929
commit 51d9c4f028
2 changed files with 25 additions and 23 deletions

View File

@ -114,6 +114,7 @@ from beam_search import greedy_search, greedy_search_batch, modified_beam_search
from dataset import naive_triplet_text_sampling, random_shuffle_subset from dataset import naive_triplet_text_sampling, random_shuffle_subset
from ls_text_normalization import word_normalization from ls_text_normalization import word_normalization
from text_normalization import ( from text_normalization import (
_apply_style_transform,
lower_all_char, lower_all_char,
lower_only_alpha, lower_only_alpha,
ref_text_normalization, ref_text_normalization,
@ -387,29 +388,6 @@ def get_parser():
return parser return parser
def _apply_style_transform(text: List[str], transform: str) -> List[str]:
    """Apply a named style transform to every string in a list.

    The incoming text is expected to be in the ground-truth style
    ("mixed-punc"), so requesting that style is a no-op.

    Args:
        text (List[str]): Strings to transform.
        transform (str): Name of the style to apply; one of "mixed-punc",
            "upper-no-punc", "lower-no-punc" or "lower-punc".

    Returns:
        List[str]: The input list itself for "mixed-punc"; otherwise a new
        list holding the matching helper's output for each element.

    Raises:
        NotImplementedError: If ``transform`` is not one of the names above.
    """
    # Ground-truth style: hand the caller's list back untouched.
    if transform == "mixed-punc":
        return text

    # Each helper is only looked up when its own branch is taken.
    if transform == "upper-no-punc":
        return list(map(upper_only_alpha, text))
    if transform == "lower-no-punc":
        return list(map(lower_only_alpha, text))
    if transform == "lower-punc":
        return list(map(lower_all_char, text))

    raise NotImplementedError(f"Unseen transform: {transform}")
def decode_one_batch( def decode_one_batch(
params: AttributeDict, params: AttributeDict,
model: nn.Module, model: nn.Module,

View File

@ -15,6 +15,7 @@
# limitations under the License. # limitations under the License.
import re import re
from typing import List
def train_text_normalization(s: str) -> str: def train_text_normalization(s: str) -> str:
@ -70,6 +71,29 @@ def upper_all_char(text: str) -> str:
return text.upper() return text.upper()
def _apply_style_transform(text: List[str], transform: str) -> List[str]:
    """Apply a named style transform to every string in a list.

    The incoming text is expected to be in the ground-truth style
    ("mixed-punc"), so requesting that style is a no-op.

    Args:
        text (List[str]): Strings to transform.
        transform (str): Name of the style to apply; one of "mixed-punc",
            "upper-no-punc", "lower-no-punc" or "lower-punc".

    Returns:
        List[str]: The input list itself for "mixed-punc"; otherwise a new
        list holding the matching helper's output for each element.

    Raises:
        NotImplementedError: If ``transform`` is not one of the names above.
    """
    # Ground-truth style: hand the caller's list back untouched.
    if transform == "mixed-punc":
        return text

    # Each helper is only looked up when its own branch is taken.
    if transform == "upper-no-punc":
        return list(map(upper_only_alpha, text))
    if transform == "lower-no-punc":
        return list(map(lower_only_alpha, text))
    if transform == "lower-punc":
        return list(map(lower_all_char, text))

    raise NotImplementedError(f"Unseen transform: {transform}")
if __name__ == "__main__": if __name__ == "__main__":
ref_text = "Mixed-case English transcription, with punctuation. Actually, it is fully not related." ref_text = "Mixed-case English transcription, with punctuation. Actually, it is fully not related."
print(ref_text) print(ref_text)