mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-06 23:54:17 +00:00
refactor code
This commit is contained in:
parent
77fc1c0929
commit
51d9c4f028
@ -114,6 +114,7 @@ from beam_search import greedy_search, greedy_search_batch, modified_beam_search
|
||||
from dataset import naive_triplet_text_sampling, random_shuffle_subset
|
||||
from ls_text_normalization import word_normalization
|
||||
from text_normalization import (
|
||||
_apply_style_transform,
|
||||
lower_all_char,
|
||||
lower_only_alpha,
|
||||
ref_text_normalization,
|
||||
@ -387,29 +388,6 @@ def get_parser():
|
||||
return parser
|
||||
|
||||
|
||||
def _apply_style_transform(text: List[str], transform: str) -> List[str]:
|
||||
"""Apply transform to a list of text. By default, the text are in
|
||||
ground truth format, i.e mixed-punc.
|
||||
|
||||
Args:
|
||||
text (List[str]): Input text string
|
||||
transform (str): Transform to be applied
|
||||
|
||||
Returns:
|
||||
List[str]: _description_
|
||||
"""
|
||||
if transform == "mixed-punc":
|
||||
return text
|
||||
elif transform == "upper-no-punc":
|
||||
return [upper_only_alpha(s) for s in text]
|
||||
elif transform == "lower-no-punc":
|
||||
return [lower_only_alpha(s) for s in text]
|
||||
elif transform == "lower-punc":
|
||||
return [lower_all_char(s) for s in text]
|
||||
else:
|
||||
raise NotImplementedError(f"Unseen transform: {transform}")
|
||||
|
||||
|
||||
def decode_one_batch(
|
||||
params: AttributeDict,
|
||||
model: nn.Module,
|
||||
|
@ -15,6 +15,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
import re
|
||||
from typing import List
|
||||
|
||||
|
||||
def train_text_normalization(s: str) -> str:
|
||||
@ -70,6 +71,29 @@ def upper_all_char(text: str) -> str:
|
||||
return text.upper()
|
||||
|
||||
|
||||
def _apply_style_transform(text: List[str], transform: str) -> List[str]:
|
||||
"""Apply transform to a list of text. By default, the text are in
|
||||
ground truth format, i.e mixed-punc.
|
||||
|
||||
Args:
|
||||
text (List[str]): Input text string
|
||||
transform (str): Transform to be applied
|
||||
|
||||
Returns:
|
||||
List[str]: _description_
|
||||
"""
|
||||
if transform == "mixed-punc":
|
||||
return text
|
||||
elif transform == "upper-no-punc":
|
||||
return [upper_only_alpha(s) for s in text]
|
||||
elif transform == "lower-no-punc":
|
||||
return [lower_only_alpha(s) for s in text]
|
||||
elif transform == "lower-punc":
|
||||
return [lower_all_char(s) for s in text]
|
||||
else:
|
||||
raise NotImplementedError(f"Unseen transform: {transform}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
ref_text = "Mixed-case English transcription, with punctuation. Actually, it is fully not related."
|
||||
print(ref_text)
|
||||
|
Loading…
x
Reference in New Issue
Block a user