mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-14 12:32:20 +00:00
minor updates
This commit is contained in:
parent
dc0106a0d5
commit
8da9acd7e1
@ -31,6 +31,14 @@ from piper_phonemize import phonemize_espeak
|
|||||||
from tqdm.auto import tqdm
|
from tqdm.auto import tqdm
|
||||||
|
|
||||||
|
|
||||||
|
def remove_punc_to_upper(text: str) -> str:
|
||||||
|
text = text.replace("‘", "'")
|
||||||
|
text = text.replace("’", "'")
|
||||||
|
tokens = set("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789'")
|
||||||
|
s_list = [x.upper() if x in tokens else " " for x in text]
|
||||||
|
s = " ".join("".join(s_list).split()).strip()
|
||||||
|
return s
|
||||||
|
|
||||||
def prepare_tokens_libritts():
|
def prepare_tokens_libritts():
|
||||||
output_dir = Path("data/spectrogram")
|
output_dir = Path("data/spectrogram")
|
||||||
prefix = "libritts"
|
prefix = "libritts"
|
||||||
@ -60,6 +68,8 @@ def prepare_tokens_libritts():
|
|||||||
for t in tokens_list:
|
for t in tokens_list:
|
||||||
tokens.extend(t)
|
tokens.extend(t)
|
||||||
cut.tokens = tokens
|
cut.tokens = tokens
|
||||||
|
cut.supervisions[0].normalized_text = remove_punc_to_upper(text)
|
||||||
|
|
||||||
new_cuts.append(cut)
|
new_cuts.append(cut)
|
||||||
|
|
||||||
new_cut_set = CutSet.from_cuts(new_cuts)
|
new_cut_set = CutSet.from_cuts(new_cuts)
|
||||||
|
@ -21,7 +21,7 @@ import argparse
|
|||||||
import logging
|
import logging
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from shutil import copyfile
|
from shutil import copyfile
|
||||||
from typing import Any, Dict, Optional, Tuple, Union
|
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
import k2
|
import k2
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@ -29,6 +29,7 @@ import torch
|
|||||||
import torch.multiprocessing as mp
|
import torch.multiprocessing as mp
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from lhotse.cut import Cut
|
from lhotse.cut import Cut
|
||||||
|
from lhotse.features.io import KaldiReader
|
||||||
from lhotse.utils import fix_random_seed
|
from lhotse.utils import fix_random_seed
|
||||||
from tokenizer import Tokenizer
|
from tokenizer import Tokenizer
|
||||||
from torch.cuda.amp import GradScaler, autocast
|
from torch.cuda.amp import GradScaler, autocast
|
||||||
@ -331,16 +332,22 @@ def prepare_input(
|
|||||||
batch: dict,
|
batch: dict,
|
||||||
tokenizer: Tokenizer,
|
tokenizer: Tokenizer,
|
||||||
device: torch.device,
|
device: torch.device,
|
||||||
speaker_map: Dict[str, int],
|
speaker_map: KaldiReader,
|
||||||
):
|
):
|
||||||
"""Parse batch data"""
|
"""Parse batch data"""
|
||||||
|
|
||||||
|
def parse_sids(batch: dict) -> List[str]:
|
||||||
|
return ["_".join(cut.id.split("_")[:2]) for cut in batch["cut"]]
|
||||||
|
|
||||||
audio = batch["audio"].to(device)
|
audio = batch["audio"].to(device)
|
||||||
features = batch["features"].to(device)
|
features = batch["features"].to(device)
|
||||||
audio_lens = batch["audio_lens"].to(device)
|
audio_lens = batch["audio_lens"].to(device)
|
||||||
features_lens = batch["features_lens"].to(device)
|
features_lens = batch["features_lens"].to(device)
|
||||||
tokens = batch["tokens"]
|
tokens = batch["tokens"]
|
||||||
speakers = (
|
speakers = (
|
||||||
torch.Tensor([speaker_map[sid] for sid in batch["speakers"]]).int().to(device)
|
torch.Tensor(np.array([speaker_map.read(sid) for sid in parse_sids(batch)]))
|
||||||
|
.squeeze(1)
|
||||||
|
.to(device)
|
||||||
)
|
)
|
||||||
|
|
||||||
tokens = tokenizer.tokens_to_token_ids(
|
tokens = tokenizer.tokens_to_token_ids(
|
||||||
@ -366,8 +373,9 @@ def train_one_epoch(
|
|||||||
scheduler_g: LRSchedulerType,
|
scheduler_g: LRSchedulerType,
|
||||||
scheduler_d: LRSchedulerType,
|
scheduler_d: LRSchedulerType,
|
||||||
train_dl: torch.utils.data.DataLoader,
|
train_dl: torch.utils.data.DataLoader,
|
||||||
valid_dl: torch.utils.data.DataLoader,
|
dev_dl: torch.utils.data.DataLoader,
|
||||||
speaker_map: Dict[str, int],
|
train_speaker_map: KaldiReader,
|
||||||
|
dev_speaker_map: KaldiReader,
|
||||||
scaler: GradScaler,
|
scaler: GradScaler,
|
||||||
tb_writer: Optional[SummaryWriter] = None,
|
tb_writer: Optional[SummaryWriter] = None,
|
||||||
world_size: int = 1,
|
world_size: int = 1,
|
||||||
@ -442,7 +450,7 @@ def train_one_epoch(
|
|||||||
tokens,
|
tokens,
|
||||||
tokens_lens,
|
tokens_lens,
|
||||||
speakers,
|
speakers,
|
||||||
) = prepare_input(batch, tokenizer, device, speaker_map)
|
) = prepare_input(batch, tokenizer, device, train_speaker_map)
|
||||||
|
|
||||||
loss_info = MetricsTracker()
|
loss_info = MetricsTracker()
|
||||||
loss_info["samples"] = batch_size
|
loss_info["samples"] = batch_size
|
||||||
@ -457,7 +465,7 @@ def train_one_epoch(
|
|||||||
feats_lengths=features_lens,
|
feats_lengths=features_lens,
|
||||||
speech=audio,
|
speech=audio,
|
||||||
speech_lengths=audio_lens,
|
speech_lengths=audio_lens,
|
||||||
sids=speakers,
|
spembs=speakers,
|
||||||
forward_generator=False,
|
forward_generator=False,
|
||||||
)
|
)
|
||||||
for k, v in stats_d.items():
|
for k, v in stats_d.items():
|
||||||
@ -476,7 +484,7 @@ def train_one_epoch(
|
|||||||
feats_lengths=features_lens,
|
feats_lengths=features_lens,
|
||||||
speech=audio,
|
speech=audio,
|
||||||
speech_lengths=audio_lens,
|
speech_lengths=audio_lens,
|
||||||
sids=speakers,
|
spembs=speakers,
|
||||||
forward_generator=True,
|
forward_generator=True,
|
||||||
return_sample=params.batch_idx_train % params.log_interval == 0,
|
return_sample=params.batch_idx_train % params.log_interval == 0,
|
||||||
)
|
)
|
||||||
@ -583,8 +591,8 @@ def train_one_epoch(
|
|||||||
params=params,
|
params=params,
|
||||||
model=model,
|
model=model,
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
valid_dl=valid_dl,
|
dev_dl=dev_dl,
|
||||||
speaker_map=speaker_map,
|
dev_speaker_map=dev_speaker_map,
|
||||||
world_size=world_size,
|
world_size=world_size,
|
||||||
)
|
)
|
||||||
model.train()
|
model.train()
|
||||||
@ -620,8 +628,8 @@ def compute_validation_loss(
|
|||||||
params: AttributeDict,
|
params: AttributeDict,
|
||||||
model: Union[nn.Module, DDP],
|
model: Union[nn.Module, DDP],
|
||||||
tokenizer: Tokenizer,
|
tokenizer: Tokenizer,
|
||||||
valid_dl: torch.utils.data.DataLoader,
|
dev_dl: torch.utils.data.DataLoader,
|
||||||
speaker_map: Dict[str, int],
|
dev_speaker_map: KaldiReader,
|
||||||
world_size: int = 1,
|
world_size: int = 1,
|
||||||
rank: int = 0,
|
rank: int = 0,
|
||||||
) -> Tuple[MetricsTracker, Tuple[np.ndarray, np.ndarray]]:
|
) -> Tuple[MetricsTracker, Tuple[np.ndarray, np.ndarray]]:
|
||||||
@ -634,7 +642,7 @@ def compute_validation_loss(
|
|||||||
returned_sample = None
|
returned_sample = None
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
for batch_idx, batch in enumerate(valid_dl):
|
for batch_idx, batch in enumerate(dev_dl):
|
||||||
batch_size = len(batch["tokens"])
|
batch_size = len(batch["tokens"])
|
||||||
(
|
(
|
||||||
audio,
|
audio,
|
||||||
@ -644,7 +652,7 @@ def compute_validation_loss(
|
|||||||
tokens,
|
tokens,
|
||||||
tokens_lens,
|
tokens_lens,
|
||||||
speakers,
|
speakers,
|
||||||
) = prepare_input(batch, tokenizer, device, speaker_map)
|
) = prepare_input(batch, tokenizer, device, dev_speaker_map)
|
||||||
|
|
||||||
loss_info = MetricsTracker()
|
loss_info = MetricsTracker()
|
||||||
loss_info["samples"] = batch_size
|
loss_info["samples"] = batch_size
|
||||||
@ -657,7 +665,7 @@ def compute_validation_loss(
|
|||||||
feats_lengths=features_lens,
|
feats_lengths=features_lens,
|
||||||
speech=audio,
|
speech=audio,
|
||||||
speech_lengths=audio_lens,
|
speech_lengths=audio_lens,
|
||||||
sids=speakers,
|
spembs=speakers,
|
||||||
forward_generator=False,
|
forward_generator=False,
|
||||||
)
|
)
|
||||||
assert loss_d.requires_grad is False
|
assert loss_d.requires_grad is False
|
||||||
@ -672,7 +680,7 @@ def compute_validation_loss(
|
|||||||
feats_lengths=features_lens,
|
feats_lengths=features_lens,
|
||||||
speech=audio,
|
speech=audio,
|
||||||
speech_lengths=audio_lens,
|
speech_lengths=audio_lens,
|
||||||
sids=speakers,
|
spembs=speakers,
|
||||||
forward_generator=True,
|
forward_generator=True,
|
||||||
)
|
)
|
||||||
assert loss_g.requires_grad is False
|
assert loss_g.requires_grad is False
|
||||||
@ -687,7 +695,7 @@ def compute_validation_loss(
|
|||||||
inner_model = model.module if isinstance(model, DDP) else model
|
inner_model = model.module if isinstance(model, DDP) else model
|
||||||
audio_pred, _, duration = inner_model.inference(
|
audio_pred, _, duration = inner_model.inference(
|
||||||
text=tokens[0, : tokens_lens[0].item()],
|
text=tokens[0, : tokens_lens[0].item()],
|
||||||
sids=speakers[0],
|
spembs=speakers[0],
|
||||||
)
|
)
|
||||||
audio_pred = audio_pred.data.cpu().numpy()
|
audio_pred = audio_pred.data.cpu().numpy()
|
||||||
audio_len_pred = (
|
audio_len_pred = (
|
||||||
@ -717,7 +725,7 @@ def scan_pessimistic_batches_for_oom(
|
|||||||
tokenizer: Tokenizer,
|
tokenizer: Tokenizer,
|
||||||
optimizer_g: torch.optim.Optimizer,
|
optimizer_g: torch.optim.Optimizer,
|
||||||
optimizer_d: torch.optim.Optimizer,
|
optimizer_d: torch.optim.Optimizer,
|
||||||
speaker_map: Dict[str, int],
|
train_speaker_map: KaldiReader,
|
||||||
params: AttributeDict,
|
params: AttributeDict,
|
||||||
):
|
):
|
||||||
from lhotse.dataset import find_pessimistic_batches
|
from lhotse.dataset import find_pessimistic_batches
|
||||||
@ -737,7 +745,7 @@ def scan_pessimistic_batches_for_oom(
|
|||||||
tokens,
|
tokens,
|
||||||
tokens_lens,
|
tokens_lens,
|
||||||
speakers,
|
speakers,
|
||||||
) = prepare_input(batch, tokenizer, device, speaker_map)
|
) = prepare_input(batch, tokenizer, device, train_speaker_map)
|
||||||
try:
|
try:
|
||||||
# for discriminator
|
# for discriminator
|
||||||
with autocast(enabled=params.use_fp16):
|
with autocast(enabled=params.use_fp16):
|
||||||
@ -748,7 +756,7 @@ def scan_pessimistic_batches_for_oom(
|
|||||||
feats_lengths=features_lens,
|
feats_lengths=features_lens,
|
||||||
speech=audio,
|
speech=audio,
|
||||||
speech_lengths=audio_lens,
|
speech_lengths=audio_lens,
|
||||||
sids=speakers,
|
spembs=speakers,
|
||||||
forward_generator=False,
|
forward_generator=False,
|
||||||
)
|
)
|
||||||
optimizer_d.zero_grad()
|
optimizer_d.zero_grad()
|
||||||
@ -762,7 +770,7 @@ def scan_pessimistic_batches_for_oom(
|
|||||||
feats_lengths=features_lens,
|
feats_lengths=features_lens,
|
||||||
speech=audio,
|
speech=audio,
|
||||||
speech_lengths=audio_lens,
|
speech_lengths=audio_lens,
|
||||||
sids=speakers,
|
spembs=speakers,
|
||||||
forward_generator=True,
|
forward_generator=True,
|
||||||
)
|
)
|
||||||
optimizer_g.zero_grad()
|
optimizer_g.zero_grad()
|
||||||
@ -820,9 +828,12 @@ def run(rank, world_size, args):
|
|||||||
|
|
||||||
libritts = LibrittsTtsDataModule(args)
|
libritts = LibrittsTtsDataModule(args)
|
||||||
|
|
||||||
train_cuts = libritts.train_cuts()
|
if params.full_libri:
|
||||||
speaker_map = libritts.speakers()
|
train_cuts = libritts.train_all_shuf_cuts()
|
||||||
params.num_spks = len(speaker_map)
|
train_speaker_map = libritts.train_all_shuf_xvector()
|
||||||
|
else:
|
||||||
|
train_cuts = libritts.train_clean_460_cuts()
|
||||||
|
train_speaker_map = libritts.train_clean_460_xvector()
|
||||||
|
|
||||||
logging.info(params)
|
logging.info(params)
|
||||||
|
|
||||||
@ -896,8 +907,9 @@ def run(rank, world_size, args):
|
|||||||
train_cuts = train_cuts.filter(remove_short_and_long_utt)
|
train_cuts = train_cuts.filter(remove_short_and_long_utt)
|
||||||
train_dl = libritts.train_dataloaders(train_cuts)
|
train_dl = libritts.train_dataloaders(train_cuts)
|
||||||
|
|
||||||
valid_cuts = libritts.valid_cuts()
|
dev_clean_cuts = libritts.dev_clean_cuts()
|
||||||
valid_dl = libritts.valid_dataloaders(valid_cuts)
|
dev_speaker_map = libritts.dev_clean_xvector()
|
||||||
|
dev_dl = libritts.dev_dataloaders(dev_clean_cuts)
|
||||||
|
|
||||||
if not params.print_diagnostics:
|
if not params.print_diagnostics:
|
||||||
scan_pessimistic_batches_for_oom(
|
scan_pessimistic_batches_for_oom(
|
||||||
@ -906,7 +918,7 @@ def run(rank, world_size, args):
|
|||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
optimizer_g=optimizer_g,
|
optimizer_g=optimizer_g,
|
||||||
optimizer_d=optimizer_d,
|
optimizer_d=optimizer_d,
|
||||||
speaker_map=speaker_map,
|
train_speaker_map=train_speaker_map,
|
||||||
params=params,
|
params=params,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -935,8 +947,9 @@ def run(rank, world_size, args):
|
|||||||
scheduler_g=scheduler_g,
|
scheduler_g=scheduler_g,
|
||||||
scheduler_d=scheduler_d,
|
scheduler_d=scheduler_d,
|
||||||
train_dl=train_dl,
|
train_dl=train_dl,
|
||||||
valid_dl=valid_dl,
|
dev_dl=dev_dl,
|
||||||
speaker_map=speaker_map,
|
train_speaker_map=train_speaker_map,
|
||||||
|
dev_speaker_map=dev_speaker_map,
|
||||||
scaler=scaler,
|
scaler=scaler,
|
||||||
tb_writer=tb_writer,
|
tb_writer=tb_writer,
|
||||||
world_size=world_size,
|
world_size=world_size,
|
||||||
|
@ -38,6 +38,7 @@ from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples
|
|||||||
AudioSamples,
|
AudioSamples,
|
||||||
OnTheFlyFeatures,
|
OnTheFlyFeatures,
|
||||||
)
|
)
|
||||||
|
from lhotse.features.io import KaldiReader
|
||||||
from lhotse.utils import fix_random_seed
|
from lhotse.utils import fix_random_seed
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
@ -51,8 +52,10 @@ class _SeedWorkers:
|
|||||||
def __call__(self, worker_id: int):
|
def __call__(self, worker_id: int):
|
||||||
fix_random_seed(self.seed + worker_id)
|
fix_random_seed(self.seed + worker_id)
|
||||||
|
|
||||||
|
|
||||||
LIBRITTS_SAMPLING_RATE = 24000
|
LIBRITTS_SAMPLING_RATE = 24000
|
||||||
|
|
||||||
|
|
||||||
class LibrittsTtsDataModule:
|
class LibrittsTtsDataModule:
|
||||||
"""
|
"""
|
||||||
DataModule for tts experiments.
|
DataModule for tts experiments.
|
||||||
@ -82,7 +85,13 @@ class LibrittsTtsDataModule:
|
|||||||
"effective batch sizes, sampling strategies, applied data "
|
"effective batch sizes, sampling strategies, applied data "
|
||||||
"augmentations, etc.",
|
"augmentations, etc.",
|
||||||
)
|
)
|
||||||
|
group.add_argument(
|
||||||
|
"--full-libri",
|
||||||
|
type=str2bool,
|
||||||
|
default=False,
|
||||||
|
help="""When enabled, use the entire LibriTTS training set.
|
||||||
|
Otherwise, use the 460h clean subset.""",
|
||||||
|
)
|
||||||
group.add_argument(
|
group.add_argument(
|
||||||
"--manifest-dir",
|
"--manifest-dir",
|
||||||
type=Path,
|
type=Path,
|
||||||
@ -90,10 +99,10 @@ class LibrittsTtsDataModule:
|
|||||||
help="Path to directory with train/valid/test cuts.",
|
help="Path to directory with train/valid/test cuts.",
|
||||||
)
|
)
|
||||||
group.add_argument(
|
group.add_argument(
|
||||||
"--speakers",
|
"--speaker-embeds",
|
||||||
type=Path,
|
type=Path,
|
||||||
default=Path("data/speakers.txt"),
|
default=Path("exp/xvector_nnet_1a/"),
|
||||||
help="Path to speakers.txt file.",
|
help="Path to directory with speaker embeddings.",
|
||||||
)
|
)
|
||||||
group.add_argument(
|
group.add_argument(
|
||||||
"--max-duration",
|
"--max-duration",
|
||||||
@ -141,7 +150,7 @@ class LibrittsTtsDataModule:
|
|||||||
group.add_argument(
|
group.add_argument(
|
||||||
"--return-cuts",
|
"--return-cuts",
|
||||||
type=str2bool,
|
type=str2bool,
|
||||||
default=False,
|
default=True,
|
||||||
help="When enabled, each batch will have the "
|
help="When enabled, each batch will have the "
|
||||||
"field: batch['cut'] with the cuts that "
|
"field: batch['cut'] with the cuts that "
|
||||||
"were used to construct it.",
|
"were used to construct it.",
|
||||||
@ -175,7 +184,7 @@ class LibrittsTtsDataModule:
|
|||||||
"""
|
"""
|
||||||
logging.info("About to create train dataset")
|
logging.info("About to create train dataset")
|
||||||
train = SpeechSynthesisDataset(
|
train = SpeechSynthesisDataset(
|
||||||
return_text=False,
|
return_text=True,
|
||||||
return_tokens=True,
|
return_tokens=True,
|
||||||
return_spk_ids=True,
|
return_spk_ids=True,
|
||||||
feature_input_strategy=eval(self.args.input_strategy)(),
|
feature_input_strategy=eval(self.args.input_strategy)(),
|
||||||
@ -191,7 +200,7 @@ class LibrittsTtsDataModule:
|
|||||||
use_fft_mag=True,
|
use_fft_mag=True,
|
||||||
)
|
)
|
||||||
train = SpeechSynthesisDataset(
|
train = SpeechSynthesisDataset(
|
||||||
return_text=False,
|
return_text=True,
|
||||||
return_tokens=True,
|
return_tokens=True,
|
||||||
return_spk_ids=True,
|
return_spk_ids=True,
|
||||||
feature_input_strategy=OnTheFlyFeatures(Spectrogram(config)),
|
feature_input_strategy=OnTheFlyFeatures(Spectrogram(config)),
|
||||||
@ -238,7 +247,7 @@ class LibrittsTtsDataModule:
|
|||||||
|
|
||||||
return train_dl
|
return train_dl
|
||||||
|
|
||||||
def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
|
def dev_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
|
||||||
logging.info("About to create dev dataset")
|
logging.info("About to create dev dataset")
|
||||||
if self.args.on_the_fly_feats:
|
if self.args.on_the_fly_feats:
|
||||||
sampling_rate = LIBRITTS_SAMPLING_RATE
|
sampling_rate = LIBRITTS_SAMPLING_RATE
|
||||||
@ -249,7 +258,7 @@ class LibrittsTtsDataModule:
|
|||||||
use_fft_mag=True,
|
use_fft_mag=True,
|
||||||
)
|
)
|
||||||
validate = SpeechSynthesisDataset(
|
validate = SpeechSynthesisDataset(
|
||||||
return_text=False,
|
return_text=True,
|
||||||
return_tokens=True,
|
return_tokens=True,
|
||||||
return_spk_ids=True,
|
return_spk_ids=True,
|
||||||
feature_input_strategy=OnTheFlyFeatures(Spectrogram(config)),
|
feature_input_strategy=OnTheFlyFeatures(Spectrogram(config)),
|
||||||
@ -257,7 +266,7 @@ class LibrittsTtsDataModule:
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
validate = SpeechSynthesisDataset(
|
validate = SpeechSynthesisDataset(
|
||||||
return_text=False,
|
return_text=True,
|
||||||
return_tokens=True,
|
return_tokens=True,
|
||||||
return_spk_ids=True,
|
return_spk_ids=True,
|
||||||
feature_input_strategy=eval(self.args.input_strategy)(),
|
feature_input_strategy=eval(self.args.input_strategy)(),
|
||||||
@ -290,7 +299,7 @@ class LibrittsTtsDataModule:
|
|||||||
use_fft_mag=True,
|
use_fft_mag=True,
|
||||||
)
|
)
|
||||||
test = SpeechSynthesisDataset(
|
test = SpeechSynthesisDataset(
|
||||||
return_text=False,
|
return_text=True,
|
||||||
return_tokens=True,
|
return_tokens=True,
|
||||||
return_spk_ids=True,
|
return_spk_ids=True,
|
||||||
feature_input_strategy=OnTheFlyFeatures(Spectrogram(config)),
|
feature_input_strategy=OnTheFlyFeatures(Spectrogram(config)),
|
||||||
@ -298,7 +307,7 @@ class LibrittsTtsDataModule:
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
test = SpeechSynthesisDataset(
|
test = SpeechSynthesisDataset(
|
||||||
return_text=False,
|
return_text=True,
|
||||||
return_tokens=True,
|
return_tokens=True,
|
||||||
return_spk_ids=True,
|
return_spk_ids=True,
|
||||||
feature_input_strategy=eval(self.args.input_strategy)(),
|
feature_input_strategy=eval(self.args.input_strategy)(),
|
||||||
@ -319,23 +328,106 @@ class LibrittsTtsDataModule:
|
|||||||
return test_dl
|
return test_dl
|
||||||
|
|
||||||
@lru_cache()
|
@lru_cache()
|
||||||
def train_cuts(self) -> CutSet:
|
def train_all_shuf_cuts(self) -> CutSet:
|
||||||
logging.info("About to get train cuts")
|
logging.info(
|
||||||
return load_manifest_lazy(self.args.manifest_dir / "vctk_cuts_train.jsonl.gz")
|
"About to get the shuffled train-clean-100, \
|
||||||
|
train-clean-360 and train-other-500 cuts"
|
||||||
|
)
|
||||||
|
return load_manifest_lazy(
|
||||||
|
self.args.manifest_dir / "libritts_cuts_with_tokens_train-all-shuf.jsonl.gz"
|
||||||
|
)
|
||||||
|
|
||||||
@lru_cache()
|
@lru_cache()
|
||||||
def valid_cuts(self) -> CutSet:
|
def train_clean_460_cuts(self) -> CutSet:
|
||||||
logging.info("About to get validation cuts")
|
logging.info(
|
||||||
return load_manifest_lazy(self.args.manifest_dir / "vctk_cuts_valid.jsonl.gz")
|
"About to get the shuffled train-clean-100 and train-clean-360 cuts"
|
||||||
|
)
|
||||||
|
return load_manifest_lazy(
|
||||||
|
self.args.manifest_dir
|
||||||
|
/ "libritts_cuts_with_tokens_train-clean-460.jsonl.gz"
|
||||||
|
)
|
||||||
|
|
||||||
@lru_cache()
|
@lru_cache()
|
||||||
def test_cuts(self) -> CutSet:
|
def dev_clean_cuts(self) -> CutSet:
|
||||||
logging.info("About to get test cuts")
|
logging.info("About to get dev-clean cuts")
|
||||||
return load_manifest_lazy(self.args.manifest_dir / "vctk_cuts_test.jsonl.gz")
|
return load_manifest_lazy(
|
||||||
|
self.args.manifest_dir / "libritts_cuts_with_tokens_dev-clean.jsonl.gz"
|
||||||
|
)
|
||||||
|
|
||||||
@lru_cache()
|
@lru_cache()
|
||||||
def speakers(self) -> Dict[str, int]:
|
def dev_other_cuts(self) -> CutSet:
|
||||||
logging.info("About to get speakers")
|
logging.info("About to get dev-other cuts")
|
||||||
with open(self.args.speakers) as f:
|
return load_manifest_lazy(
|
||||||
speakers = {line.strip(): i for i, line in enumerate(f)}
|
self.args.manifest_dir / "libritts_cuts_with_tokens_dev-other.jsonl.gz"
|
||||||
return speakers
|
)
|
||||||
|
|
||||||
|
@lru_cache()
|
||||||
|
def test_clean_cuts(self) -> CutSet:
|
||||||
|
logging.info("About to get test-clean cuts")
|
||||||
|
return load_manifest_lazy(
|
||||||
|
self.args.manifest_dir / "libritts_cuts_with_tokens_test-clean.jsonl.gz"
|
||||||
|
)
|
||||||
|
|
||||||
|
@lru_cache()
|
||||||
|
def test_other_cuts(self) -> CutSet:
|
||||||
|
logging.info("About to get test-other cuts")
|
||||||
|
return load_manifest_lazy(
|
||||||
|
self.args.manifest_dir / "libritts_cuts_with_tokens_test-other.jsonl.gz"
|
||||||
|
)
|
||||||
|
|
||||||
|
@lru_cache()
|
||||||
|
def train_all_shuf_xvector(self) -> KaldiReader:
|
||||||
|
raise NotImplementedError(
|
||||||
|
"Please implement the method to load speaker embeddings."
|
||||||
|
)
|
||||||
|
|
||||||
|
@lru_cache()
|
||||||
|
def train_clean_460_xvector(self) -> KaldiReader:
|
||||||
|
logging.info("About to get speaker embeddings for train-clean-460")
|
||||||
|
return KaldiReader(
|
||||||
|
str(self.args.speaker_embeds / "xvectors_train_clean_460" / "feats.scp")
|
||||||
|
)
|
||||||
|
|
||||||
|
@lru_cache()
|
||||||
|
def train_clean_100_xvector(self) -> KaldiReader:
|
||||||
|
raise NotImplementedError(
|
||||||
|
"Please implement the method to load speaker embeddings."
|
||||||
|
)
|
||||||
|
|
||||||
|
@lru_cache()
|
||||||
|
def train_clean_360_xvector(self) -> KaldiReader:
|
||||||
|
raise NotImplementedError(
|
||||||
|
"Please implement the method to load speaker embeddings."
|
||||||
|
)
|
||||||
|
|
||||||
|
@lru_cache()
|
||||||
|
def train_other_500_xvector(self) -> KaldiReader:
|
||||||
|
raise NotImplementedError(
|
||||||
|
"Please implement the method to load speaker embeddings."
|
||||||
|
)
|
||||||
|
|
||||||
|
@lru_cache()
|
||||||
|
def dev_clean_xvector(self) -> KaldiReader:
|
||||||
|
logging.info("About to get speaker embeddings for dev-clean")
|
||||||
|
return KaldiReader(
|
||||||
|
str(self.args.speaker_embeds / "xvectors_dev_clean" / "feats.scp")
|
||||||
|
)
|
||||||
|
|
||||||
|
@lru_cache()
|
||||||
|
def dev_other_xvector(self) -> KaldiReader:
|
||||||
|
raise NotImplementedError(
|
||||||
|
"Please implement the method to load speaker embeddings."
|
||||||
|
)
|
||||||
|
|
||||||
|
@lru_cache()
|
||||||
|
def test_clean_xvector(self) -> KaldiReader:
|
||||||
|
logging.info("About to get speaker embeddings for test-clean")
|
||||||
|
return KaldiReader(
|
||||||
|
str(self.args.speaker_embeds / "xvectors_test_clean" / "feats.scp")
|
||||||
|
)
|
||||||
|
|
||||||
|
@lru_cache()
|
||||||
|
def test_other_xvector(self) -> KaldiReader:
|
||||||
|
raise NotImplementedError(
|
||||||
|
"Please implement the method to load speaker embeddings."
|
||||||
|
)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user