minor updates

JinZr 2024-10-21 17:10:40 +08:00
parent dc0106a0d5
commit 8da9acd7e1
3 changed files with 170 additions and 55 deletions

egs/libritts/TTS/local/prepare_tokens_libritts.py

@@ -31,6 +31,14 @@ from piper_phonemize import phonemize_espeak
from tqdm.auto import tqdm
def remove_punc_to_upper(text: str) -> str:
text = text.replace("‘", "'")
text = text.replace("’", "'")
tokens = set("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789'")
s_list = [x.upper() if x in tokens else " " for x in text]
s = " ".join("".join(s_list).split()).strip()
return s
def prepare_tokens_libritts():
output_dir = Path("data/spectrogram")
prefix = "libritts"
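For reference, the new remove_punc_to_upper keeps letters, digits, and apostrophes, uppercases them, and collapses everything else to single spaces. A quick sanity check of the normalization (illustrative, not part of the commit):

# Curly apostrophes are mapped to "'" first; other punctuation becomes whitespace.
assert remove_punc_to_upper("Hello, world! It’s 2024.") == "HELLO WORLD IT'S 2024"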
@@ -60,6 +68,8 @@ def prepare_tokens_libritts():
for t in tokens_list:
tokens.extend(t)
cut.tokens = tokens
cut.supervisions[0].normalized_text = remove_punc_to_upper(text)
new_cuts.append(cut)
new_cut_set = CutSet.from_cuts(new_cuts)

egs/libritts/TTS/vits/train.py

@@ -21,7 +21,7 @@ import argparse
import logging
from pathlib import Path
from shutil import copyfile
from typing import Any, Dict, Optional, Tuple, Union
from typing import Any, Dict, List, Optional, Tuple, Union
import k2
import numpy as np
@@ -29,6 +29,7 @@ import torch
import torch.multiprocessing as mp
import torch.nn as nn
from lhotse.cut import Cut
from lhotse.features.io import KaldiReader
from lhotse.utils import fix_random_seed
from tokenizer import Tokenizer
from torch.cuda.amp import GradScaler, autocast
@@ -331,16 +332,22 @@ def prepare_input(
batch: dict,
tokenizer: Tokenizer,
device: torch.device,
speaker_map: Dict[str, int],
speaker_map: KaldiReader,
):
"""Parse batch data"""
def parse_sids(batch: dict) -> List[str]:
return ["_".join(cut.id.split("_")[:2]) for cut in batch["cut"]]
audio = batch["audio"].to(device)
features = batch["features"].to(device)
audio_lens = batch["audio_lens"].to(device)
features_lens = batch["features_lens"].to(device)
tokens = batch["tokens"]
speakers = (
torch.Tensor([speaker_map[sid] for sid in batch["speakers"]]).int().to(device)
torch.Tensor(np.array([speaker_map.read(sid) for sid in parse_sids(batch)]))
.squeeze(1)
.to(device)
)
tokens = tokenizer.tokens_to_token_ids(
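The speaker lookup now reads precomputed x-vectors from a Kaldi feats.scp instead of indexing a name-to-id dict. A minimal sketch of what parse_sids and KaldiReader.read produce, assuming LibriTTS cut ids of the form speaker_chapter_paragraph_utterance (the example cut id is an assumption; the scp path matches the train_clean_460_xvector loader later in this commit):

from lhotse.features.io import KaldiReader

reader = KaldiReader("exp/xvector_nnet_1a/xvectors_train_clean_460/feats.scp")
# "1089_134686_000001_000001" -> "1089_134686": the speaker and chapter fields.
sid = "_".join("1089_134686_000001_000001".split("_")[:2])
# read() returns an array with a leading singleton dimension; prepare_input
# squeezes it before moving the stacked batch to the device.
xvector = reader.read(sid)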
@@ -366,8 +373,9 @@ def train_one_epoch(
scheduler_g: LRSchedulerType,
scheduler_d: LRSchedulerType,
train_dl: torch.utils.data.DataLoader,
valid_dl: torch.utils.data.DataLoader,
speaker_map: Dict[str, int],
dev_dl: torch.utils.data.DataLoader,
train_speaker_map: KaldiReader,
dev_speaker_map: KaldiReader,
scaler: GradScaler,
tb_writer: Optional[SummaryWriter] = None,
world_size: int = 1,
@@ -442,7 +450,7 @@
tokens,
tokens_lens,
speakers,
) = prepare_input(batch, tokenizer, device, speaker_map)
) = prepare_input(batch, tokenizer, device, train_speaker_map)
loss_info = MetricsTracker()
loss_info["samples"] = batch_size
@@ -457,7 +465,7 @@
feats_lengths=features_lens,
speech=audio,
speech_lengths=audio_lens,
sids=speakers,
spembs=speakers,
forward_generator=False,
)
for k, v in stats_d.items():
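Renaming sids to spembs switches the model inputs from integer speaker ids, which index a learned embedding table, to precomputed speaker embeddings passed in directly. A schematic sketch of the difference (the module and names here are illustrative, not the actual VITS code):

import torch
import torch.nn as nn

class SpeakerConditioning(nn.Module):
    """Illustrative contrast between sids- and spembs-style conditioning."""

    def __init__(self, num_spks: int, embed_dim: int = 512):
        super().__init__()
        # sids path: a trainable table, limited to speakers seen in training.
        self.table = nn.Embedding(num_spks, embed_dim)

    def forward(self, sids=None, spembs=None):
        # spembs path: externally computed x-vectors are used as-is, so new
        # speakers can be conditioned on without retraining the table.
        return self.table(sids) if spembs is None else spembs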
@@ -476,7 +484,7 @@
feats_lengths=features_lens,
speech=audio,
speech_lengths=audio_lens,
sids=speakers,
spembs=speakers,
forward_generator=True,
return_sample=params.batch_idx_train % params.log_interval == 0,
)
@@ -583,8 +591,8 @@ def train_one_epoch(
params=params,
model=model,
tokenizer=tokenizer,
valid_dl=valid_dl,
speaker_map=speaker_map,
dev_dl=dev_dl,
dev_speaker_map=dev_speaker_map,
world_size=world_size,
)
model.train()
@@ -620,8 +628,8 @@ def compute_validation_loss(
params: AttributeDict,
model: Union[nn.Module, DDP],
tokenizer: Tokenizer,
valid_dl: torch.utils.data.DataLoader,
speaker_map: Dict[str, int],
dev_dl: torch.utils.data.DataLoader,
dev_speaker_map: KaldiReader,
world_size: int = 1,
rank: int = 0,
) -> Tuple[MetricsTracker, Tuple[np.ndarray, np.ndarray]]:
@@ -634,7 +642,7 @@
returned_sample = None
with torch.no_grad():
for batch_idx, batch in enumerate(valid_dl):
for batch_idx, batch in enumerate(dev_dl):
batch_size = len(batch["tokens"])
(
audio,
@@ -644,7 +652,7 @@
tokens,
tokens_lens,
speakers,
) = prepare_input(batch, tokenizer, device, speaker_map)
) = prepare_input(batch, tokenizer, device, dev_speaker_map)
loss_info = MetricsTracker()
loss_info["samples"] = batch_size
@@ -657,7 +665,7 @@
feats_lengths=features_lens,
speech=audio,
speech_lengths=audio_lens,
sids=speakers,
spembs=speakers,
forward_generator=False,
)
assert loss_d.requires_grad is False
@@ -672,7 +680,7 @@
feats_lengths=features_lens,
speech=audio,
speech_lengths=audio_lens,
sids=speakers,
spembs=speakers,
forward_generator=True,
)
assert loss_g.requires_grad is False
@@ -687,7 +695,7 @@
inner_model = model.module if isinstance(model, DDP) else model
audio_pred, _, duration = inner_model.inference(
text=tokens[0, : tokens_lens[0].item()],
sids=speakers[0],
spembs=speakers[0],
)
audio_pred = audio_pred.data.cpu().numpy()
audio_len_pred = (
@@ -717,7 +725,7 @@ def scan_pessimistic_batches_for_oom(
tokenizer: Tokenizer,
optimizer_g: torch.optim.Optimizer,
optimizer_d: torch.optim.Optimizer,
speaker_map: Dict[str, int],
train_speaker_map: KaldiReader,
params: AttributeDict,
):
from lhotse.dataset import find_pessimistic_batches
@@ -737,7 +745,7 @@
tokens,
tokens_lens,
speakers,
) = prepare_input(batch, tokenizer, device, speaker_map)
) = prepare_input(batch, tokenizer, device, train_speaker_map)
try:
# for discriminator
with autocast(enabled=params.use_fp16):
@@ -748,7 +756,7 @@
feats_lengths=features_lens,
speech=audio,
speech_lengths=audio_lens,
sids=speakers,
spembs=speakers,
forward_generator=False,
)
optimizer_d.zero_grad()
@@ -762,7 +770,7 @@
feats_lengths=features_lens,
speech=audio,
speech_lengths=audio_lens,
sids=speakers,
spembs=speakers,
forward_generator=True,
)
optimizer_g.zero_grad()
@@ -820,9 +828,12 @@ def run(rank, world_size, args):
libritts = LibrittsTtsDataModule(args)
train_cuts = libritts.train_cuts()
speaker_map = libritts.speakers()
params.num_spks = len(speaker_map)
if params.full_libri:
train_cuts = libritts.train_all_shuf_cuts()
train_speaker_map = libritts.train_all_shuf_xvector()
else:
train_cuts = libritts.train_clean_460_cuts()
train_speaker_map = libritts.train_clean_460_xvector()
logging.info(params)
@@ -896,8 +907,9 @@ def run(rank, world_size, args):
train_cuts = train_cuts.filter(remove_short_and_long_utt)
train_dl = libritts.train_dataloaders(train_cuts)
valid_cuts = libritts.valid_cuts()
valid_dl = libritts.valid_dataloaders(valid_cuts)
dev_clean_cuts = libritts.dev_clean_cuts()
dev_speaker_map = libritts.dev_clean_xvector()
dev_dl = libritts.dev_dataloaders(dev_clean_cuts)
if not params.print_diagnostics:
scan_pessimistic_batches_for_oom(
@@ -906,7 +918,7 @@ def run(rank, world_size, args):
tokenizer=tokenizer,
optimizer_g=optimizer_g,
optimizer_d=optimizer_d,
speaker_map=speaker_map,
train_speaker_map=train_speaker_map,
params=params,
)
@@ -935,8 +947,9 @@ def run(rank, world_size, args):
scheduler_g=scheduler_g,
scheduler_d=scheduler_d,
train_dl=train_dl,
valid_dl=valid_dl,
speaker_map=speaker_map,
dev_dl=dev_dl,
train_speaker_map=train_speaker_map,
dev_speaker_map=dev_speaker_map,
scaler=scaler,
tb_writer=tb_writer,
world_size=world_size,

egs/libritts/TTS/vits/tts_datamodule.py

@@ -38,6 +38,7 @@ from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples
AudioSamples,
OnTheFlyFeatures,
)
from lhotse.features.io import KaldiReader
from lhotse.utils import fix_random_seed
from torch.utils.data import DataLoader
@@ -51,8 +52,10 @@ class _SeedWorkers:
def __call__(self, worker_id: int):
fix_random_seed(self.seed + worker_id)
LIBRITTS_SAMPLING_RATE = 24000
class LibrittsTtsDataModule:
"""
DataModule for tts experiments.
@@ -82,7 +85,13 @@ class LibrittsTtsDataModule:
"effective batch sizes, sampling strategies, applied data "
"augmentations, etc.",
)
group.add_argument(
"--full-libri",
type=str2bool,
default=False,
help="""When enabled, use the entire LibriTTS training set.
Otherwise, use the 460h clean subset.""",
)
group.add_argument(
"--manifest-dir",
type=Path,
@@ -90,10 +99,10 @@
help="Path to directory with train/valid/test cuts.",
)
group.add_argument(
"--speakers",
"--speaker-embeds",
type=Path,
default=Path("data/speakers.txt"),
help="Path to speakers.txt file.",
default=Path("exp/xvector_nnet_1a/"),
help="Path to directory with speaker embeddings.",
)
group.add_argument(
"--max-duration",
@@ -141,7 +150,7 @@
group.add_argument(
"--return-cuts",
type=str2bool,
default=False,
default=True,
help="When enabled, each batch will have the "
"field: batch['cut'] with the cuts that "
"were used to construct it.",
@@ -175,7 +184,7 @@
"""
logging.info("About to create train dataset")
train = SpeechSynthesisDataset(
return_text=False,
return_text=True,
return_tokens=True,
return_spk_ids=True,
feature_input_strategy=eval(self.args.input_strategy)(),
@@ -191,7 +200,7 @@
use_fft_mag=True,
)
train = SpeechSynthesisDataset(
return_text=False,
return_text=True,
return_tokens=True,
return_spk_ids=True,
feature_input_strategy=OnTheFlyFeatures(Spectrogram(config)),
@@ -238,7 +247,7 @@
return train_dl
def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
def dev_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
logging.info("About to create dev dataset")
if self.args.on_the_fly_feats:
sampling_rate = LIBRITTS_SAMPLING_RATE
@@ -249,7 +258,7 @@
use_fft_mag=True,
)
validate = SpeechSynthesisDataset(
return_text=False,
return_text=True,
return_tokens=True,
return_spk_ids=True,
feature_input_strategy=OnTheFlyFeatures(Spectrogram(config)),
@@ -257,7 +266,7 @@
)
else:
validate = SpeechSynthesisDataset(
return_text=False,
return_text=True,
return_tokens=True,
return_spk_ids=True,
feature_input_strategy=eval(self.args.input_strategy)(),
@@ -290,7 +299,7 @@
use_fft_mag=True,
)
test = SpeechSynthesisDataset(
return_text=False,
return_text=True,
return_tokens=True,
return_spk_ids=True,
feature_input_strategy=OnTheFlyFeatures(Spectrogram(config)),
@@ -298,7 +307,7 @@
)
else:
test = SpeechSynthesisDataset(
return_text=False,
return_text=True,
return_tokens=True,
return_spk_ids=True,
feature_input_strategy=eval(self.args.input_strategy)(),
@@ -319,23 +328,106 @@
return test_dl
@lru_cache()
def train_cuts(self) -> CutSet:
logging.info("About to get train cuts")
return load_manifest_lazy(self.args.manifest_dir / "vctk_cuts_train.jsonl.gz")
def train_all_shuf_cuts(self) -> CutSet:
logging.info(
"About to get the shuffled train-clean-100, \
train-clean-360 and train-other-500 cuts"
)
return load_manifest_lazy(
self.args.manifest_dir / "libritts_cuts_with_tokens_train-all-shuf.jsonl.gz"
)
@lru_cache()
def valid_cuts(self) -> CutSet:
logging.info("About to get validation cuts")
return load_manifest_lazy(self.args.manifest_dir / "vctk_cuts_valid.jsonl.gz")
def train_clean_460_cuts(self) -> CutSet:
logging.info(
"About to get the shuffled train-clean-100 and train-clean-360 cuts"
)
return load_manifest_lazy(
self.args.manifest_dir
/ "libritts_cuts_with_tokens_train-clean-460.jsonl.gz"
)
@lru_cache()
def test_cuts(self) -> CutSet:
logging.info("About to get test cuts")
return load_manifest_lazy(self.args.manifest_dir / "vctk_cuts_test.jsonl.gz")
def dev_clean_cuts(self) -> CutSet:
logging.info("About to get dev-clean cuts")
return load_manifest_lazy(
self.args.manifest_dir / "libritts_cuts_with_tokens_dev-clean.jsonl.gz"
)
@lru_cache()
def speakers(self) -> Dict[str, int]:
logging.info("About to get speakers")
with open(self.args.speakers) as f:
speakers = {line.strip(): i for i, line in enumerate(f)}
return speakers
def dev_other_cuts(self) -> CutSet:
logging.info("About to get dev-other cuts")
return load_manifest_lazy(
self.args.manifest_dir / "libritts_cuts_with_tokens_dev-other.jsonl.gz"
)
@lru_cache()
def test_clean_cuts(self) -> CutSet:
logging.info("About to get test-clean cuts")
return load_manifest_lazy(
self.args.manifest_dir / "libritts_cuts_with_tokens_test-clean.jsonl.gz"
)
@lru_cache()
def test_other_cuts(self) -> CutSet:
logging.info("About to get test-other cuts")
return load_manifest_lazy(
self.args.manifest_dir / "libritts_cuts_with_tokens_test-other.jsonl.gz"
)
@lru_cache()
def train_all_shuf_xvector(self) -> KaldiReader:
raise NotImplementedError(
"Please implement the method to load speaker embeddings."
)
@lru_cache()
def train_clean_460_xvector(self) -> KaldiReader:
logging.info("About to get speaker embeddings for train-clean-460")
return KaldiReader(
str(self.args.speaker_embeds / "xvectors_train_clean_460" / "feats.scp")
)
@lru_cache()
def train_clean_100_xvector(self) -> KaldiReader:
raise NotImplementedError(
"Please implement the method to load speaker embeddings."
)
@lru_cache()
def train_clean_360_xvector(self) -> KaldiReader:
raise NotImplementedError(
"Please implement the method to load speaker embeddings."
)
@lru_cache()
def train_other_500_xvector(self) -> KaldiReader:
raise NotImplementedError(
"Please implement the method to load speaker embeddings."
)
@lru_cache()
def dev_clean_xvector(self) -> KaldiReader:
logging.info("About to get speaker embeddings for dev-clean")
return KaldiReader(
str(self.args.speaker_embeds / "xvectors_dev_clean" / "feats.scp")
)
@lru_cache()
def dev_other_xvector(self) -> KaldiReader:
raise NotImplementedError(
"Please implement the method to load speaker embeddings."
)
@lru_cache()
def test_clean_xvector(self) -> KaldiReader:
logging.info("About to get speaker embeddings for test-clean")
return KaldiReader(
str(self.args.speaker_embeds / "xvectors_test_clean" / "feats.scp")
)
@lru_cache()
def test_other_xvector(self) -> KaldiReader:
raise NotImplementedError(
"Please implement the method to load speaker embeddings."
)
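The loaders that raise NotImplementedError mark subsets whose x-vectors have not been extracted yet. If the embeddings are later extracted in the same layout, one loader could follow the implemented pattern as sketched below; the directory name xvectors_dev_other is an assumption mirroring the dev-clean and test-clean loaders:

@lru_cache()
def dev_other_xvector(self) -> KaldiReader:
    logging.info("About to get speaker embeddings for dev-other")
    # Assumes exp/xvector_nnet_1a/xvectors_dev_other/feats.scp exists.
    return KaldiReader(
        str(self.args.speaker_embeds / "xvectors_dev_other" / "feats.scp")
    )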