mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-26 10:16:14 +00:00
Applied text normalization to the validation and test cuts
This commit is contained in:
parent
f0744877a6
commit
156af46a6e
@ -123,7 +123,7 @@ from beam_search import (
|
|||||||
modified_beam_search_LODR,
|
modified_beam_search_LODR,
|
||||||
)
|
)
|
||||||
from lhotse import set_caching_enabled
|
from lhotse import set_caching_enabled
|
||||||
from train import add_model_arguments, get_model, get_params
|
from train import add_model_arguments, get_model, get_params, normalize_text
|
||||||
|
|
||||||
from icefall import ContextGraph, LmScorer, NgramLm
|
from icefall import ContextGraph, LmScorer, NgramLm
|
||||||
from icefall.checkpoint import (
|
from icefall.checkpoint import (
|
||||||
@ -1043,8 +1043,8 @@ def main():
|
|||||||
args.return_cuts = True
|
args.return_cuts = True
|
||||||
libritts = LibriTTSAsrDataModule(args)
|
libritts = LibriTTSAsrDataModule(args)
|
||||||
|
|
||||||
test_clean_cuts = libritts.test_clean_cuts()
|
test_clean_cuts = libritts.test_clean_cuts().map(normalize_text)
|
||||||
test_other_cuts = libritts.test_other_cuts()
|
test_other_cuts = libritts.test_other_cuts().map(normalize_text)
|
||||||
|
|
||||||
test_clean_dl = libritts.test_dataloaders(test_clean_cuts)
|
test_clean_dl = libritts.test_dataloaders(test_clean_cuts)
|
||||||
test_other_dl = libritts.test_dataloaders(test_other_cuts)
|
test_other_dl = libritts.test_dataloaders(test_other_cuts)
|
||||||
|
@ -603,7 +603,8 @@ def _to_int_tuple(s: str):
|
|||||||
return tuple(map(int, s.split(",")))
|
return tuple(map(int, s.split(",")))
|
||||||
|
|
||||||
|
|
||||||
def remove_punc_to_upper(text: str) -> str:
|
def normalize_text(c: Cut):
|
||||||
|
def remove_punc_to_upper(text: str) -> str:
|
||||||
text = text.replace("‘", "'")
|
text = text.replace("‘", "'")
|
||||||
text = text.replace("’", "'")
|
text = text.replace("’", "'")
|
||||||
tokens = set("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789'")
|
tokens = set("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789'")
|
||||||
@ -611,6 +612,10 @@ def remove_punc_to_upper(text: str) -> str:
|
|||||||
s = " ".join("".join(s_list).split()).strip()
|
s = " ".join("".join(s_list).split()).strip()
|
||||||
return s
|
return s
|
||||||
|
|
||||||
|
text = remove_punc_to_upper(c.supervisions[0].text)
|
||||||
|
c.supervisions[0].text = text
|
||||||
|
return c
|
||||||
|
|
||||||
|
|
||||||
def get_encoder_embed(params: AttributeDict) -> nn.Module:
|
def get_encoder_embed(params: AttributeDict) -> nn.Module:
|
||||||
# encoder_embed converts the input of shape (N, T, num_features)
|
# encoder_embed converts the input of shape (N, T, num_features)
|
||||||
@ -1309,11 +1314,6 @@ def run(rank, world_size, args):
|
|||||||
else:
|
else:
|
||||||
train_cuts = libritts.train_clean_100_cuts()
|
train_cuts = libritts.train_clean_100_cuts()
|
||||||
|
|
||||||
def normalize_text(c: Cut):
|
|
||||||
text = remove_punc_to_upper(c.supervisions[0].text)
|
|
||||||
c.supervisions[0].text = text
|
|
||||||
return c
|
|
||||||
|
|
||||||
def remove_short_and_long_utt(c: Cut):
|
def remove_short_and_long_utt(c: Cut):
|
||||||
# Keep only utterances with duration between 1 second and 20 seconds
|
# Keep only utterances with duration between 1 second and 20 seconds
|
||||||
#
|
#
|
||||||
@ -1365,8 +1365,8 @@ def run(rank, world_size, args):
|
|||||||
train_cuts, sampler_state_dict=sampler_state_dict
|
train_cuts, sampler_state_dict=sampler_state_dict
|
||||||
)
|
)
|
||||||
|
|
||||||
valid_cuts = libritts.dev_clean_cuts()
|
valid_cuts = libritts.dev_clean_cuts().map(normalize_text)
|
||||||
valid_cuts += libritts.dev_other_cuts()
|
valid_cuts += libritts.dev_other_cuts().map(normalize_text)
|
||||||
valid_dl = libritts.valid_dataloaders(valid_cuts)
|
valid_dl = libritts.valid_dataloaders(valid_cuts)
|
||||||
|
|
||||||
if not params.print_diagnostics:
|
if not params.print_diagnostics:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user