applied text norm to valid & test cuts

This commit is contained in:
JinZr 2024-10-08 00:02:16 +08:00
parent f0744877a6
commit 156af46a6e
2 changed files with 17 additions and 17 deletions

View File

@@ -123,7 +123,7 @@ from beam_search import (
     modified_beam_search_LODR,
 )
 from lhotse import set_caching_enabled
-from train import add_model_arguments, get_model, get_params
+from train import add_model_arguments, get_model, get_params, normalize_text
 from icefall import ContextGraph, LmScorer, NgramLm
 from icefall.checkpoint import (
@@ -1043,8 +1043,8 @@ def main():
     args.return_cuts = True
     libritts = LibriTTSAsrDataModule(args)
-    test_clean_cuts = libritts.test_clean_cuts()
-    test_other_cuts = libritts.test_other_cuts()
+    test_clean_cuts = libritts.test_clean_cuts().map(normalize_text)
+    test_other_cuts = libritts.test_other_cuts().map(normalize_text)
     test_clean_dl = libritts.test_dataloaders(test_clean_cuts)
     test_other_dl = libritts.test_dataloaders(test_other_cuts)

View File

@@ -603,7 +603,8 @@ def _to_int_tuple(s: str):
     return tuple(map(int, s.split(",")))

-def remove_punc_to_upper(text: str) -> str:
-    text = text.replace("’", "'")
-    text = text.replace("‘", "'")
-    tokens = set("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789'")
+def normalize_text(c: Cut):
+    def remove_punc_to_upper(text: str) -> str:
+        text = text.replace("’", "'")
+        text = text.replace("‘", "'")
+        tokens = set("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789'")
@@ -611,6 +612,10 @@ def remove_punc_to_upper(text: str) -> str:
-    s = " ".join("".join(s_list).split()).strip()
-    return s
+        s = " ".join("".join(s_list).split()).strip()
+        return s
+
+    text = remove_punc_to_upper(c.supervisions[0].text)
+    c.supervisions[0].text = text
+    return c
def get_encoder_embed(params: AttributeDict) -> nn.Module: def get_encoder_embed(params: AttributeDict) -> nn.Module:
# encoder_embed converts the input of shape (N, T, num_features) # encoder_embed converts the input of shape (N, T, num_features)
@@ -1309,11 +1314,6 @@ def run(rank, world_size, args):
     else:
         train_cuts = libritts.train_clean_100_cuts()

-    def normalize_text(c: Cut):
-        text = remove_punc_to_upper(c.supervisions[0].text)
-        c.supervisions[0].text = text
-        return c
-
     def remove_short_and_long_utt(c: Cut):
         # Keep only utterances with duration between 1 second and 20 seconds
         #
@@ -1365,8 +1365,8 @@ def run(rank, world_size, args):
         train_cuts, sampler_state_dict=sampler_state_dict
     )

-    valid_cuts = libritts.dev_clean_cuts()
-    valid_cuts += libritts.dev_other_cuts()
+    valid_cuts = libritts.dev_clean_cuts().map(normalize_text)
+    valid_cuts += libritts.dev_other_cuts().map(normalize_text)
     valid_dl = libritts.valid_dataloaders(valid_cuts)

     if not params.print_diagnostics: