diff --git a/egs/libritts/ASR/zipformer/decode.py b/egs/libritts/ASR/zipformer/decode.py
index 15267b0cb..759d9d50a 100755
--- a/egs/libritts/ASR/zipformer/decode.py
+++ b/egs/libritts/ASR/zipformer/decode.py
@@ -123,7 +123,7 @@ from beam_search import (
     modified_beam_search_LODR,
 )
 from lhotse import set_caching_enabled
-from train import add_model_arguments, get_model, get_params
+from train import add_model_arguments, get_model, get_params, normalize_text
 
 from icefall import ContextGraph, LmScorer, NgramLm
 from icefall.checkpoint import (
@@ -1043,8 +1043,8 @@ def main():
     args.return_cuts = True
     libritts = LibriTTSAsrDataModule(args)
 
-    test_clean_cuts = libritts.test_clean_cuts()
-    test_other_cuts = libritts.test_other_cuts()
+    test_clean_cuts = libritts.test_clean_cuts().map(normalize_text)
+    test_other_cuts = libritts.test_other_cuts().map(normalize_text)
 
     test_clean_dl = libritts.test_dataloaders(test_clean_cuts)
     test_other_dl = libritts.test_dataloaders(test_other_cuts)
diff --git a/egs/libritts/ASR/zipformer/train.py b/egs/libritts/ASR/zipformer/train.py
index 0fa32d7f6..5485eaf0a 100755
--- a/egs/libritts/ASR/zipformer/train.py
+++ b/egs/libritts/ASR/zipformer/train.py
@@ -603,13 +603,18 @@ def _to_int_tuple(s: str):
     return tuple(map(int, s.split(",")))
 
 
-def remove_punc_to_upper(text: str) -> str:
-    text = text.replace("‘", "'")
-    text = text.replace("’", "'")
-    tokens = set("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789'")
-    s_list = [x.upper() if x in tokens else " " for x in text]
-    s = " ".join("".join(s_list).split()).strip()
-    return s
+def normalize_text(c: Cut):
+    def remove_punc_to_upper(text: str) -> str:
+        text = text.replace("‘", "'")
+        text = text.replace("’", "'")
+        tokens = set("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789'")
+        s_list = [x.upper() if x in tokens else " " for x in text]
+        s = " ".join("".join(s_list).split()).strip()
+        return s
+
+    text = remove_punc_to_upper(c.supervisions[0].text)
+    c.supervisions[0].text = text
+    return c
 
 
 def get_encoder_embed(params: AttributeDict) -> nn.Module:
@@ -1309,11 +1314,6 @@ def run(rank, world_size, args):
     else:
         train_cuts = libritts.train_clean_100_cuts()
 
-    def normalize_text(c: Cut):
-        text = remove_punc_to_upper(c.supervisions[0].text)
-        c.supervisions[0].text = text
-        return c
-
     def remove_short_and_long_utt(c: Cut):
         # Keep only utterances with duration between 1 second and 20 seconds
         #
@@ -1365,8 +1365,8 @@ def run(rank, world_size, args):
         train_cuts, sampler_state_dict=sampler_state_dict
     )
 
-    valid_cuts = libritts.dev_clean_cuts()
-    valid_cuts += libritts.dev_other_cuts()
+    valid_cuts = libritts.dev_clean_cuts().map(normalize_text)
+    valid_cuts += libritts.dev_other_cuts().map(normalize_text)
     valid_dl = libritts.valid_dataloaders(valid_cuts)
 
     if not params.print_diagnostics: