minor updates

This commit is contained in:
JinZr 2024-10-21 17:10:40 +08:00
parent dc0106a0d5
commit 8da9acd7e1
3 changed files with 170 additions and 55 deletions

View File

@ -31,6 +31,14 @@ from piper_phonemize import phonemize_espeak
from tqdm.auto import tqdm from tqdm.auto import tqdm
def remove_punc_to_upper(text: str) -> str:
    """Uppercase *text*, keeping only [A-Za-z0-9'] and collapsing every
    other character into a single space.

    Typographic (curly) apostrophes are normalized to the ASCII
    apostrophe first so contractions such as "don't" survive filtering.
    """
    # Normalize U+2019 / U+2018 curly quotes to "'".  The extracted source
    # had empty-string first arguments here, and str.replace("", "'")
    # interleaves an apostrophe between EVERY character — clearly a
    # mojibake loss of the curly-quote code points, restored below.
    text = text.replace("\u2019", "'")
    text = text.replace("\u2018", "'")
    tokens = set("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789'")
    # Keep alphanumerics and apostrophes (uppercased); everything else
    # becomes a space.
    s_list = [x.upper() if x in tokens else " " for x in text]
    # Collapse runs of whitespace and trim the ends.
    s = " ".join("".join(s_list).split()).strip()
    return s
def prepare_tokens_libritts(): def prepare_tokens_libritts():
output_dir = Path("data/spectrogram") output_dir = Path("data/spectrogram")
prefix = "libritts" prefix = "libritts"
@ -60,6 +68,8 @@ def prepare_tokens_libritts():
for t in tokens_list: for t in tokens_list:
tokens.extend(t) tokens.extend(t)
cut.tokens = tokens cut.tokens = tokens
cut.supervisions[0].normalized_text = remove_punc_to_upper(text)
new_cuts.append(cut) new_cuts.append(cut)
new_cut_set = CutSet.from_cuts(new_cuts) new_cut_set = CutSet.from_cuts(new_cuts)

View File

@ -21,7 +21,7 @@ import argparse
import logging import logging
from pathlib import Path from pathlib import Path
from shutil import copyfile from shutil import copyfile
from typing import Any, Dict, Optional, Tuple, Union from typing import Any, Dict, List, Optional, Tuple, Union
import k2 import k2
import numpy as np import numpy as np
@ -29,6 +29,7 @@ import torch
import torch.multiprocessing as mp import torch.multiprocessing as mp
import torch.nn as nn import torch.nn as nn
from lhotse.cut import Cut from lhotse.cut import Cut
from lhotse.features.io import KaldiReader
from lhotse.utils import fix_random_seed from lhotse.utils import fix_random_seed
from tokenizer import Tokenizer from tokenizer import Tokenizer
from torch.cuda.amp import GradScaler, autocast from torch.cuda.amp import GradScaler, autocast
@ -331,16 +332,22 @@ def prepare_input(
batch: dict, batch: dict,
tokenizer: Tokenizer, tokenizer: Tokenizer,
device: torch.device, device: torch.device,
speaker_map: Dict[str, int], speaker_map: KaldiReader,
): ):
"""Parse batch data""" """Parse batch data"""
def parse_sids(batch: dict) -> List[str]:
return ["_".join(cut.id.split("_")[:2]) for cut in batch["cut"]]
audio = batch["audio"].to(device) audio = batch["audio"].to(device)
features = batch["features"].to(device) features = batch["features"].to(device)
audio_lens = batch["audio_lens"].to(device) audio_lens = batch["audio_lens"].to(device)
features_lens = batch["features_lens"].to(device) features_lens = batch["features_lens"].to(device)
tokens = batch["tokens"] tokens = batch["tokens"]
speakers = ( speakers = (
torch.Tensor([speaker_map[sid] for sid in batch["speakers"]]).int().to(device) torch.Tensor(np.array([speaker_map.read(sid) for sid in parse_sids(batch)]))
.squeeze(1)
.to(device)
) )
tokens = tokenizer.tokens_to_token_ids( tokens = tokenizer.tokens_to_token_ids(
@ -366,8 +373,9 @@ def train_one_epoch(
scheduler_g: LRSchedulerType, scheduler_g: LRSchedulerType,
scheduler_d: LRSchedulerType, scheduler_d: LRSchedulerType,
train_dl: torch.utils.data.DataLoader, train_dl: torch.utils.data.DataLoader,
valid_dl: torch.utils.data.DataLoader, dev_dl: torch.utils.data.DataLoader,
speaker_map: Dict[str, int], train_speaker_map: KaldiReader,
dev_speaker_map: KaldiReader,
scaler: GradScaler, scaler: GradScaler,
tb_writer: Optional[SummaryWriter] = None, tb_writer: Optional[SummaryWriter] = None,
world_size: int = 1, world_size: int = 1,
@ -442,7 +450,7 @@ def train_one_epoch(
tokens, tokens,
tokens_lens, tokens_lens,
speakers, speakers,
) = prepare_input(batch, tokenizer, device, speaker_map) ) = prepare_input(batch, tokenizer, device, train_speaker_map)
loss_info = MetricsTracker() loss_info = MetricsTracker()
loss_info["samples"] = batch_size loss_info["samples"] = batch_size
@ -457,7 +465,7 @@ def train_one_epoch(
feats_lengths=features_lens, feats_lengths=features_lens,
speech=audio, speech=audio,
speech_lengths=audio_lens, speech_lengths=audio_lens,
sids=speakers, spembs=speakers,
forward_generator=False, forward_generator=False,
) )
for k, v in stats_d.items(): for k, v in stats_d.items():
@ -476,7 +484,7 @@ def train_one_epoch(
feats_lengths=features_lens, feats_lengths=features_lens,
speech=audio, speech=audio,
speech_lengths=audio_lens, speech_lengths=audio_lens,
sids=speakers, spembs=speakers,
forward_generator=True, forward_generator=True,
return_sample=params.batch_idx_train % params.log_interval == 0, return_sample=params.batch_idx_train % params.log_interval == 0,
) )
@ -583,8 +591,8 @@ def train_one_epoch(
params=params, params=params,
model=model, model=model,
tokenizer=tokenizer, tokenizer=tokenizer,
valid_dl=valid_dl, dev_dl=dev_dl,
speaker_map=speaker_map, dev_speaker_map=dev_speaker_map,
world_size=world_size, world_size=world_size,
) )
model.train() model.train()
@ -620,8 +628,8 @@ def compute_validation_loss(
params: AttributeDict, params: AttributeDict,
model: Union[nn.Module, DDP], model: Union[nn.Module, DDP],
tokenizer: Tokenizer, tokenizer: Tokenizer,
valid_dl: torch.utils.data.DataLoader, dev_dl: torch.utils.data.DataLoader,
speaker_map: Dict[str, int], dev_speaker_map: KaldiReader,
world_size: int = 1, world_size: int = 1,
rank: int = 0, rank: int = 0,
) -> Tuple[MetricsTracker, Tuple[np.ndarray, np.ndarray]]: ) -> Tuple[MetricsTracker, Tuple[np.ndarray, np.ndarray]]:
@ -634,7 +642,7 @@ def compute_validation_loss(
returned_sample = None returned_sample = None
with torch.no_grad(): with torch.no_grad():
for batch_idx, batch in enumerate(valid_dl): for batch_idx, batch in enumerate(dev_dl):
batch_size = len(batch["tokens"]) batch_size = len(batch["tokens"])
( (
audio, audio,
@ -644,7 +652,7 @@ def compute_validation_loss(
tokens, tokens,
tokens_lens, tokens_lens,
speakers, speakers,
) = prepare_input(batch, tokenizer, device, speaker_map) ) = prepare_input(batch, tokenizer, device, dev_speaker_map)
loss_info = MetricsTracker() loss_info = MetricsTracker()
loss_info["samples"] = batch_size loss_info["samples"] = batch_size
@ -657,7 +665,7 @@ def compute_validation_loss(
feats_lengths=features_lens, feats_lengths=features_lens,
speech=audio, speech=audio,
speech_lengths=audio_lens, speech_lengths=audio_lens,
sids=speakers, spembs=speakers,
forward_generator=False, forward_generator=False,
) )
assert loss_d.requires_grad is False assert loss_d.requires_grad is False
@ -672,7 +680,7 @@ def compute_validation_loss(
feats_lengths=features_lens, feats_lengths=features_lens,
speech=audio, speech=audio,
speech_lengths=audio_lens, speech_lengths=audio_lens,
sids=speakers, spembs=speakers,
forward_generator=True, forward_generator=True,
) )
assert loss_g.requires_grad is False assert loss_g.requires_grad is False
@ -687,7 +695,7 @@ def compute_validation_loss(
inner_model = model.module if isinstance(model, DDP) else model inner_model = model.module if isinstance(model, DDP) else model
audio_pred, _, duration = inner_model.inference( audio_pred, _, duration = inner_model.inference(
text=tokens[0, : tokens_lens[0].item()], text=tokens[0, : tokens_lens[0].item()],
sids=speakers[0], spembs=speakers[0],
) )
audio_pred = audio_pred.data.cpu().numpy() audio_pred = audio_pred.data.cpu().numpy()
audio_len_pred = ( audio_len_pred = (
@ -717,7 +725,7 @@ def scan_pessimistic_batches_for_oom(
tokenizer: Tokenizer, tokenizer: Tokenizer,
optimizer_g: torch.optim.Optimizer, optimizer_g: torch.optim.Optimizer,
optimizer_d: torch.optim.Optimizer, optimizer_d: torch.optim.Optimizer,
speaker_map: Dict[str, int], train_speaker_map: KaldiReader,
params: AttributeDict, params: AttributeDict,
): ):
from lhotse.dataset import find_pessimistic_batches from lhotse.dataset import find_pessimistic_batches
@ -737,7 +745,7 @@ def scan_pessimistic_batches_for_oom(
tokens, tokens,
tokens_lens, tokens_lens,
speakers, speakers,
) = prepare_input(batch, tokenizer, device, speaker_map) ) = prepare_input(batch, tokenizer, device, train_speaker_map)
try: try:
# for discriminator # for discriminator
with autocast(enabled=params.use_fp16): with autocast(enabled=params.use_fp16):
@ -748,7 +756,7 @@ def scan_pessimistic_batches_for_oom(
feats_lengths=features_lens, feats_lengths=features_lens,
speech=audio, speech=audio,
speech_lengths=audio_lens, speech_lengths=audio_lens,
sids=speakers, spembs=speakers,
forward_generator=False, forward_generator=False,
) )
optimizer_d.zero_grad() optimizer_d.zero_grad()
@ -762,7 +770,7 @@ def scan_pessimistic_batches_for_oom(
feats_lengths=features_lens, feats_lengths=features_lens,
speech=audio, speech=audio,
speech_lengths=audio_lens, speech_lengths=audio_lens,
sids=speakers, spembs=speakers,
forward_generator=True, forward_generator=True,
) )
optimizer_g.zero_grad() optimizer_g.zero_grad()
@ -820,9 +828,12 @@ def run(rank, world_size, args):
libritts = LibrittsTtsDataModule(args) libritts = LibrittsTtsDataModule(args)
train_cuts = libritts.train_cuts() if params.full_libri:
speaker_map = libritts.speakers() train_cuts = libritts.train_all_shuf_cuts()
params.num_spks = len(speaker_map) train_speaker_map = libritts.train_all_shuf_xvector()
else:
train_cuts = libritts.train_clean_460_cuts()
train_speaker_map = libritts.train_clean_460_xvector()
logging.info(params) logging.info(params)
@ -896,8 +907,9 @@ def run(rank, world_size, args):
train_cuts = train_cuts.filter(remove_short_and_long_utt) train_cuts = train_cuts.filter(remove_short_and_long_utt)
train_dl = libritts.train_dataloaders(train_cuts) train_dl = libritts.train_dataloaders(train_cuts)
valid_cuts = libritts.valid_cuts() dev_clean_cuts = libritts.dev_clean_cuts()
valid_dl = libritts.valid_dataloaders(valid_cuts) dev_speaker_map = libritts.dev_clean_xvector()
dev_dl = libritts.dev_dataloaders(dev_clean_cuts)
if not params.print_diagnostics: if not params.print_diagnostics:
scan_pessimistic_batches_for_oom( scan_pessimistic_batches_for_oom(
@ -906,7 +918,7 @@ def run(rank, world_size, args):
tokenizer=tokenizer, tokenizer=tokenizer,
optimizer_g=optimizer_g, optimizer_g=optimizer_g,
optimizer_d=optimizer_d, optimizer_d=optimizer_d,
speaker_map=speaker_map, train_speaker_map=train_speaker_map,
params=params, params=params,
) )
@ -935,8 +947,9 @@ def run(rank, world_size, args):
scheduler_g=scheduler_g, scheduler_g=scheduler_g,
scheduler_d=scheduler_d, scheduler_d=scheduler_d,
train_dl=train_dl, train_dl=train_dl,
valid_dl=valid_dl, dev_dl=dev_dl,
speaker_map=speaker_map, train_speaker_map=train_speaker_map,
dev_speaker_map=dev_speaker_map,
scaler=scaler, scaler=scaler,
tb_writer=tb_writer, tb_writer=tb_writer,
world_size=world_size, world_size=world_size,

View File

@ -38,6 +38,7 @@ from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples
AudioSamples, AudioSamples,
OnTheFlyFeatures, OnTheFlyFeatures,
) )
from lhotse.features.io import KaldiReader
from lhotse.utils import fix_random_seed from lhotse.utils import fix_random_seed
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
@ -51,8 +52,10 @@ class _SeedWorkers:
def __call__(self, worker_id: int): def __call__(self, worker_id: int):
fix_random_seed(self.seed + worker_id) fix_random_seed(self.seed + worker_id)
LIBRITTS_SAMPLING_RATE = 24000 LIBRITTS_SAMPLING_RATE = 24000
class LibrittsTtsDataModule: class LibrittsTtsDataModule:
""" """
DataModule for tts experiments. DataModule for tts experiments.
@ -82,7 +85,13 @@ class LibrittsTtsDataModule:
"effective batch sizes, sampling strategies, applied data " "effective batch sizes, sampling strategies, applied data "
"augmentations, etc.", "augmentations, etc.",
) )
group.add_argument(
"--full-libri",
type=str2bool,
default=False,
help="""When enabled, use the entire LibriTTS training set.
Otherwise, use the 460h clean subset.""",
)
group.add_argument( group.add_argument(
"--manifest-dir", "--manifest-dir",
type=Path, type=Path,
@ -90,10 +99,10 @@ class LibrittsTtsDataModule:
help="Path to directory with train/valid/test cuts.", help="Path to directory with train/valid/test cuts.",
) )
group.add_argument( group.add_argument(
"--speakers", "--speaker-embeds",
type=Path, type=Path,
default=Path("data/speakers.txt"), default=Path("exp/xvector_nnet_1a/"),
help="Path to speakers.txt file.", help="Path to directory with speaker embeddings.",
) )
group.add_argument( group.add_argument(
"--max-duration", "--max-duration",
@ -141,7 +150,7 @@ class LibrittsTtsDataModule:
group.add_argument( group.add_argument(
"--return-cuts", "--return-cuts",
type=str2bool, type=str2bool,
default=False, default=True,
help="When enabled, each batch will have the " help="When enabled, each batch will have the "
"field: batch['cut'] with the cuts that " "field: batch['cut'] with the cuts that "
"were used to construct it.", "were used to construct it.",
@ -175,7 +184,7 @@ class LibrittsTtsDataModule:
""" """
logging.info("About to create train dataset") logging.info("About to create train dataset")
train = SpeechSynthesisDataset( train = SpeechSynthesisDataset(
return_text=False, return_text=True,
return_tokens=True, return_tokens=True,
return_spk_ids=True, return_spk_ids=True,
feature_input_strategy=eval(self.args.input_strategy)(), feature_input_strategy=eval(self.args.input_strategy)(),
@ -191,7 +200,7 @@ class LibrittsTtsDataModule:
use_fft_mag=True, use_fft_mag=True,
) )
train = SpeechSynthesisDataset( train = SpeechSynthesisDataset(
return_text=False, return_text=True,
return_tokens=True, return_tokens=True,
return_spk_ids=True, return_spk_ids=True,
feature_input_strategy=OnTheFlyFeatures(Spectrogram(config)), feature_input_strategy=OnTheFlyFeatures(Spectrogram(config)),
@ -238,7 +247,7 @@ class LibrittsTtsDataModule:
return train_dl return train_dl
def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader: def dev_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
logging.info("About to create dev dataset") logging.info("About to create dev dataset")
if self.args.on_the_fly_feats: if self.args.on_the_fly_feats:
sampling_rate = LIBRITTS_SAMPLING_RATE sampling_rate = LIBRITTS_SAMPLING_RATE
@ -249,7 +258,7 @@ class LibrittsTtsDataModule:
use_fft_mag=True, use_fft_mag=True,
) )
validate = SpeechSynthesisDataset( validate = SpeechSynthesisDataset(
return_text=False, return_text=True,
return_tokens=True, return_tokens=True,
return_spk_ids=True, return_spk_ids=True,
feature_input_strategy=OnTheFlyFeatures(Spectrogram(config)), feature_input_strategy=OnTheFlyFeatures(Spectrogram(config)),
@ -257,7 +266,7 @@ class LibrittsTtsDataModule:
) )
else: else:
validate = SpeechSynthesisDataset( validate = SpeechSynthesisDataset(
return_text=False, return_text=True,
return_tokens=True, return_tokens=True,
return_spk_ids=True, return_spk_ids=True,
feature_input_strategy=eval(self.args.input_strategy)(), feature_input_strategy=eval(self.args.input_strategy)(),
@ -290,7 +299,7 @@ class LibrittsTtsDataModule:
use_fft_mag=True, use_fft_mag=True,
) )
test = SpeechSynthesisDataset( test = SpeechSynthesisDataset(
return_text=False, return_text=True,
return_tokens=True, return_tokens=True,
return_spk_ids=True, return_spk_ids=True,
feature_input_strategy=OnTheFlyFeatures(Spectrogram(config)), feature_input_strategy=OnTheFlyFeatures(Spectrogram(config)),
@ -298,7 +307,7 @@ class LibrittsTtsDataModule:
) )
else: else:
test = SpeechSynthesisDataset( test = SpeechSynthesisDataset(
return_text=False, return_text=True,
return_tokens=True, return_tokens=True,
return_spk_ids=True, return_spk_ids=True,
feature_input_strategy=eval(self.args.input_strategy)(), feature_input_strategy=eval(self.args.input_strategy)(),
@ -319,23 +328,106 @@ class LibrittsTtsDataModule:
return test_dl return test_dl
@lru_cache() @lru_cache()
def train_cuts(self) -> CutSet: def train_all_shuf_cuts(self) -> CutSet:
logging.info("About to get train cuts") logging.info(
return load_manifest_lazy(self.args.manifest_dir / "vctk_cuts_train.jsonl.gz") "About to get the shuffled train-clean-100, \
train-clean-360 and train-other-500 cuts"
)
return load_manifest_lazy(
self.args.manifest_dir / "libritts_cuts_with_tokens_train-all-shuf.jsonl.gz"
)
@lru_cache() @lru_cache()
def valid_cuts(self) -> CutSet: def train_clean_460_cuts(self) -> CutSet:
logging.info("About to get validation cuts") logging.info(
return load_manifest_lazy(self.args.manifest_dir / "vctk_cuts_valid.jsonl.gz") "About to get the shuffled train-clean-100 and train-clean-360 cuts"
)
return load_manifest_lazy(
self.args.manifest_dir
/ "libritts_cuts_with_tokens_train-clean-460.jsonl.gz"
)
@lru_cache() @lru_cache()
def test_cuts(self) -> CutSet: def dev_clean_cuts(self) -> CutSet:
logging.info("About to get test cuts") logging.info("About to get dev-clean cuts")
return load_manifest_lazy(self.args.manifest_dir / "vctk_cuts_test.jsonl.gz") return load_manifest_lazy(
self.args.manifest_dir / "libritts_cuts_with_tokens_dev-clean.jsonl.gz"
)
@lru_cache() @lru_cache()
def speakers(self) -> Dict[str, int]: def dev_other_cuts(self) -> CutSet:
logging.info("About to get speakers") logging.info("About to get dev-other cuts")
with open(self.args.speakers) as f: return load_manifest_lazy(
speakers = {line.strip(): i for i, line in enumerate(f)} self.args.manifest_dir / "libritts_cuts_with_tokens_dev-other.jsonl.gz"
return speakers )
@lru_cache()
def test_clean_cuts(self) -> CutSet:
    """Lazily load the LibriTTS test-clean cut manifest (cached)."""
    logging.info("About to get test-clean cuts")
    manifest = self.args.manifest_dir / "libritts_cuts_with_tokens_test-clean.jsonl.gz"
    return load_manifest_lazy(manifest)
@lru_cache()
def test_other_cuts(self) -> CutSet:
    """Lazily load the LibriTTS test-other cut manifest (cached)."""
    logging.info("About to get test-other cuts")
    manifest = self.args.manifest_dir / "libritts_cuts_with_tokens_test-other.jsonl.gz"
    return load_manifest_lazy(manifest)
@lru_cache()
def train_all_shuf_xvector(self) -> KaldiReader:
    """Speaker embeddings for the full shuffled train set (not implemented)."""
    message = "Please implement the method to load speaker embeddings."
    raise NotImplementedError(message)
@lru_cache()
def train_clean_460_xvector(self) -> KaldiReader:
    """Return a cached KaldiReader over the train-clean-460 x-vectors."""
    logging.info("About to get speaker embeddings for train-clean-460")
    scp = self.args.speaker_embeds / "xvectors_train_clean_460" / "feats.scp"
    return KaldiReader(str(scp))
@lru_cache()
def train_clean_100_xvector(self) -> KaldiReader:
    """Speaker embeddings for train-clean-100 (not implemented)."""
    message = "Please implement the method to load speaker embeddings."
    raise NotImplementedError(message)
@lru_cache()
def train_clean_360_xvector(self) -> KaldiReader:
    """Speaker embeddings for train-clean-360 (not implemented)."""
    message = "Please implement the method to load speaker embeddings."
    raise NotImplementedError(message)
@lru_cache()
def train_other_500_xvector(self) -> KaldiReader:
    """Speaker embeddings for train-other-500 (not implemented)."""
    message = "Please implement the method to load speaker embeddings."
    raise NotImplementedError(message)
@lru_cache()
def dev_clean_xvector(self) -> KaldiReader:
    """Return a cached KaldiReader over the dev-clean x-vectors."""
    logging.info("About to get speaker embeddings for dev-clean")
    scp = self.args.speaker_embeds / "xvectors_dev_clean" / "feats.scp"
    return KaldiReader(str(scp))
@lru_cache()
def dev_other_xvector(self) -> KaldiReader:
    """Speaker embeddings for dev-other (not implemented)."""
    message = "Please implement the method to load speaker embeddings."
    raise NotImplementedError(message)
@lru_cache()
def test_clean_xvector(self) -> KaldiReader:
    """Return a cached KaldiReader over the test-clean x-vectors."""
    logging.info("About to get speaker embeddings for test-clean")
    scp = self.args.speaker_embeds / "xvectors_test_clean" / "feats.scp"
    return KaldiReader(str(scp))
@lru_cache()
def test_other_xvector(self) -> KaldiReader:
    """Speaker embeddings for test-other (not implemented)."""
    message = "Please implement the method to load speaker embeddings."
    raise NotImplementedError(message)