mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-07 08:04:18 +00:00
refactor code
This commit is contained in:
parent
77fc1c0929
commit
51d9c4f028
@ -114,6 +114,7 @@ from beam_search import greedy_search, greedy_search_batch, modified_beam_search
|
|||||||
from dataset import naive_triplet_text_sampling, random_shuffle_subset
|
from dataset import naive_triplet_text_sampling, random_shuffle_subset
|
||||||
from ls_text_normalization import word_normalization
|
from ls_text_normalization import word_normalization
|
||||||
from text_normalization import (
|
from text_normalization import (
|
||||||
|
_apply_style_transform,
|
||||||
lower_all_char,
|
lower_all_char,
|
||||||
lower_only_alpha,
|
lower_only_alpha,
|
||||||
ref_text_normalization,
|
ref_text_normalization,
|
||||||
@ -387,29 +388,6 @@ def get_parser():
|
|||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
def _apply_style_transform(text: List[str], transform: str) -> List[str]:
|
|
||||||
"""Apply transform to a list of text. By default, the text are in
|
|
||||||
ground truth format, i.e mixed-punc.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
text (List[str]): Input text string
|
|
||||||
transform (str): Transform to be applied
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List[str]: _description_
|
|
||||||
"""
|
|
||||||
if transform == "mixed-punc":
|
|
||||||
return text
|
|
||||||
elif transform == "upper-no-punc":
|
|
||||||
return [upper_only_alpha(s) for s in text]
|
|
||||||
elif transform == "lower-no-punc":
|
|
||||||
return [lower_only_alpha(s) for s in text]
|
|
||||||
elif transform == "lower-punc":
|
|
||||||
return [lower_all_char(s) for s in text]
|
|
||||||
else:
|
|
||||||
raise NotImplementedError(f"Unseen transform: {transform}")
|
|
||||||
|
|
||||||
|
|
||||||
def decode_one_batch(
|
def decode_one_batch(
|
||||||
params: AttributeDict,
|
params: AttributeDict,
|
||||||
model: nn.Module,
|
model: nn.Module,
|
||||||
|
@ -15,6 +15,7 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import re
|
import re
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
|
||||||
def train_text_normalization(s: str) -> str:
|
def train_text_normalization(s: str) -> str:
|
||||||
@ -70,6 +71,29 @@ def upper_all_char(text: str) -> str:
|
|||||||
return text.upper()
|
return text.upper()
|
||||||
|
|
||||||
|
|
||||||
|
def _apply_style_transform(text: List[str], transform: str) -> List[str]:
|
||||||
|
"""Apply transform to a list of text. By default, the text are in
|
||||||
|
ground truth format, i.e mixed-punc.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text (List[str]): Input text string
|
||||||
|
transform (str): Transform to be applied
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List[str]: _description_
|
||||||
|
"""
|
||||||
|
if transform == "mixed-punc":
|
||||||
|
return text
|
||||||
|
elif transform == "upper-no-punc":
|
||||||
|
return [upper_only_alpha(s) for s in text]
|
||||||
|
elif transform == "lower-no-punc":
|
||||||
|
return [lower_only_alpha(s) for s in text]
|
||||||
|
elif transform == "lower-punc":
|
||||||
|
return [lower_all_char(s) for s in text]
|
||||||
|
else:
|
||||||
|
raise NotImplementedError(f"Unseen transform: {transform}")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
ref_text = "Mixed-case English transcription, with punctuation. Actually, it is fully not related."
|
ref_text = "Mixed-case English transcription, with punctuation. Actually, it is fully not related."
|
||||||
print(ref_text)
|
print(ref_text)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user