mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-14 12:32:20 +00:00
minor updates
This commit is contained in:
parent
dc0106a0d5
commit
8da9acd7e1
@ -31,6 +31,14 @@ from piper_phonemize import phonemize_espeak
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
|
||||
def remove_punc_to_upper(text: str) -> str:
    """Uppercase *text*, keeping only letters, digits and apostrophes.

    Typographic (curly) apostrophes are first normalized to ASCII ones.
    Every character outside the kept set is turned into a space, and runs
    of whitespace are collapsed to single spaces.
    """
    # Normalize curly apostrophes so contractions survive the filter below.
    text = text.replace("‘", "'").replace("’", "'")
    kept = set("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789'")
    filtered = "".join(ch.upper() if ch in kept else " " for ch in text)
    return " ".join(filtered.split())
|
||||
|
||||
def prepare_tokens_libritts():
|
||||
output_dir = Path("data/spectrogram")
|
||||
prefix = "libritts"
|
||||
@ -60,6 +68,8 @@ def prepare_tokens_libritts():
|
||||
for t in tokens_list:
|
||||
tokens.extend(t)
|
||||
cut.tokens = tokens
|
||||
cut.supervisions[0].normalized_text = remove_punc_to_upper(text)
|
||||
|
||||
new_cuts.append(cut)
|
||||
|
||||
new_cut_set = CutSet.from_cuts(new_cuts)
|
||||
|
@ -21,7 +21,7 @@ import argparse
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from shutil import copyfile
|
||||
from typing import Any, Dict, Optional, Tuple, Union
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import k2
|
||||
import numpy as np
|
||||
@ -29,6 +29,7 @@ import torch
|
||||
import torch.multiprocessing as mp
|
||||
import torch.nn as nn
|
||||
from lhotse.cut import Cut
|
||||
from lhotse.features.io import KaldiReader
|
||||
from lhotse.utils import fix_random_seed
|
||||
from tokenizer import Tokenizer
|
||||
from torch.cuda.amp import GradScaler, autocast
|
||||
@ -331,16 +332,22 @@ def prepare_input(
|
||||
batch: dict,
|
||||
tokenizer: Tokenizer,
|
||||
device: torch.device,
|
||||
speaker_map: Dict[str, int],
|
||||
speaker_map: KaldiReader,
|
||||
):
|
||||
"""Parse batch data"""
|
||||
|
||||
def parse_sids(batch: dict) -> List[str]:
|
||||
return ["_".join(cut.id.split("_")[:2]) for cut in batch["cut"]]
|
||||
|
||||
audio = batch["audio"].to(device)
|
||||
features = batch["features"].to(device)
|
||||
audio_lens = batch["audio_lens"].to(device)
|
||||
features_lens = batch["features_lens"].to(device)
|
||||
tokens = batch["tokens"]
|
||||
speakers = (
|
||||
torch.Tensor([speaker_map[sid] for sid in batch["speakers"]]).int().to(device)
|
||||
torch.Tensor(np.array([speaker_map.read(sid) for sid in parse_sids(batch)]))
|
||||
.squeeze(1)
|
||||
.to(device)
|
||||
)
|
||||
|
||||
tokens = tokenizer.tokens_to_token_ids(
|
||||
@ -366,8 +373,9 @@ def train_one_epoch(
|
||||
scheduler_g: LRSchedulerType,
|
||||
scheduler_d: LRSchedulerType,
|
||||
train_dl: torch.utils.data.DataLoader,
|
||||
valid_dl: torch.utils.data.DataLoader,
|
||||
speaker_map: Dict[str, int],
|
||||
dev_dl: torch.utils.data.DataLoader,
|
||||
train_speaker_map: KaldiReader,
|
||||
dev_speaker_map: KaldiReader,
|
||||
scaler: GradScaler,
|
||||
tb_writer: Optional[SummaryWriter] = None,
|
||||
world_size: int = 1,
|
||||
@ -442,7 +450,7 @@ def train_one_epoch(
|
||||
tokens,
|
||||
tokens_lens,
|
||||
speakers,
|
||||
) = prepare_input(batch, tokenizer, device, speaker_map)
|
||||
) = prepare_input(batch, tokenizer, device, train_speaker_map)
|
||||
|
||||
loss_info = MetricsTracker()
|
||||
loss_info["samples"] = batch_size
|
||||
@ -457,7 +465,7 @@ def train_one_epoch(
|
||||
feats_lengths=features_lens,
|
||||
speech=audio,
|
||||
speech_lengths=audio_lens,
|
||||
sids=speakers,
|
||||
spembs=speakers,
|
||||
forward_generator=False,
|
||||
)
|
||||
for k, v in stats_d.items():
|
||||
@ -476,7 +484,7 @@ def train_one_epoch(
|
||||
feats_lengths=features_lens,
|
||||
speech=audio,
|
||||
speech_lengths=audio_lens,
|
||||
sids=speakers,
|
||||
spembs=speakers,
|
||||
forward_generator=True,
|
||||
return_sample=params.batch_idx_train % params.log_interval == 0,
|
||||
)
|
||||
@ -583,8 +591,8 @@ def train_one_epoch(
|
||||
params=params,
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
valid_dl=valid_dl,
|
||||
speaker_map=speaker_map,
|
||||
dev_dl=dev_dl,
|
||||
dev_speaker_map=dev_speaker_map,
|
||||
world_size=world_size,
|
||||
)
|
||||
model.train()
|
||||
@ -620,8 +628,8 @@ def compute_validation_loss(
|
||||
params: AttributeDict,
|
||||
model: Union[nn.Module, DDP],
|
||||
tokenizer: Tokenizer,
|
||||
valid_dl: torch.utils.data.DataLoader,
|
||||
speaker_map: Dict[str, int],
|
||||
dev_dl: torch.utils.data.DataLoader,
|
||||
dev_speaker_map: KaldiReader,
|
||||
world_size: int = 1,
|
||||
rank: int = 0,
|
||||
) -> Tuple[MetricsTracker, Tuple[np.ndarray, np.ndarray]]:
|
||||
@ -634,7 +642,7 @@ def compute_validation_loss(
|
||||
returned_sample = None
|
||||
|
||||
with torch.no_grad():
|
||||
for batch_idx, batch in enumerate(valid_dl):
|
||||
for batch_idx, batch in enumerate(dev_dl):
|
||||
batch_size = len(batch["tokens"])
|
||||
(
|
||||
audio,
|
||||
@ -644,7 +652,7 @@ def compute_validation_loss(
|
||||
tokens,
|
||||
tokens_lens,
|
||||
speakers,
|
||||
) = prepare_input(batch, tokenizer, device, speaker_map)
|
||||
) = prepare_input(batch, tokenizer, device, dev_speaker_map)
|
||||
|
||||
loss_info = MetricsTracker()
|
||||
loss_info["samples"] = batch_size
|
||||
@ -657,7 +665,7 @@ def compute_validation_loss(
|
||||
feats_lengths=features_lens,
|
||||
speech=audio,
|
||||
speech_lengths=audio_lens,
|
||||
sids=speakers,
|
||||
spembs=speakers,
|
||||
forward_generator=False,
|
||||
)
|
||||
assert loss_d.requires_grad is False
|
||||
@ -672,7 +680,7 @@ def compute_validation_loss(
|
||||
feats_lengths=features_lens,
|
||||
speech=audio,
|
||||
speech_lengths=audio_lens,
|
||||
sids=speakers,
|
||||
spembs=speakers,
|
||||
forward_generator=True,
|
||||
)
|
||||
assert loss_g.requires_grad is False
|
||||
@ -687,7 +695,7 @@ def compute_validation_loss(
|
||||
inner_model = model.module if isinstance(model, DDP) else model
|
||||
audio_pred, _, duration = inner_model.inference(
|
||||
text=tokens[0, : tokens_lens[0].item()],
|
||||
sids=speakers[0],
|
||||
spembs=speakers[0],
|
||||
)
|
||||
audio_pred = audio_pred.data.cpu().numpy()
|
||||
audio_len_pred = (
|
||||
@ -717,7 +725,7 @@ def scan_pessimistic_batches_for_oom(
|
||||
tokenizer: Tokenizer,
|
||||
optimizer_g: torch.optim.Optimizer,
|
||||
optimizer_d: torch.optim.Optimizer,
|
||||
speaker_map: Dict[str, int],
|
||||
train_speaker_map: KaldiReader,
|
||||
params: AttributeDict,
|
||||
):
|
||||
from lhotse.dataset import find_pessimistic_batches
|
||||
@ -737,7 +745,7 @@ def scan_pessimistic_batches_for_oom(
|
||||
tokens,
|
||||
tokens_lens,
|
||||
speakers,
|
||||
) = prepare_input(batch, tokenizer, device, speaker_map)
|
||||
) = prepare_input(batch, tokenizer, device, train_speaker_map)
|
||||
try:
|
||||
# for discriminator
|
||||
with autocast(enabled=params.use_fp16):
|
||||
@ -748,7 +756,7 @@ def scan_pessimistic_batches_for_oom(
|
||||
feats_lengths=features_lens,
|
||||
speech=audio,
|
||||
speech_lengths=audio_lens,
|
||||
sids=speakers,
|
||||
spembs=speakers,
|
||||
forward_generator=False,
|
||||
)
|
||||
optimizer_d.zero_grad()
|
||||
@ -762,7 +770,7 @@ def scan_pessimistic_batches_for_oom(
|
||||
feats_lengths=features_lens,
|
||||
speech=audio,
|
||||
speech_lengths=audio_lens,
|
||||
sids=speakers,
|
||||
spembs=speakers,
|
||||
forward_generator=True,
|
||||
)
|
||||
optimizer_g.zero_grad()
|
||||
@ -820,9 +828,12 @@ def run(rank, world_size, args):
|
||||
|
||||
libritts = LibrittsTtsDataModule(args)
|
||||
|
||||
train_cuts = libritts.train_cuts()
|
||||
speaker_map = libritts.speakers()
|
||||
params.num_spks = len(speaker_map)
|
||||
if params.full_libri:
|
||||
train_cuts = libritts.train_all_shuf_cuts()
|
||||
train_speaker_map = libritts.train_all_shuf_xvector()
|
||||
else:
|
||||
train_cuts = libritts.train_clean_460_cuts()
|
||||
train_speaker_map = libritts.train_clean_460_xvector()
|
||||
|
||||
logging.info(params)
|
||||
|
||||
@ -896,8 +907,9 @@ def run(rank, world_size, args):
|
||||
train_cuts = train_cuts.filter(remove_short_and_long_utt)
|
||||
train_dl = libritts.train_dataloaders(train_cuts)
|
||||
|
||||
valid_cuts = libritts.valid_cuts()
|
||||
valid_dl = libritts.valid_dataloaders(valid_cuts)
|
||||
dev_clean_cuts = libritts.dev_clean_cuts()
|
||||
dev_speaker_map = libritts.dev_clean_xvector()
|
||||
dev_dl = libritts.dev_dataloaders(dev_clean_cuts)
|
||||
|
||||
if not params.print_diagnostics:
|
||||
scan_pessimistic_batches_for_oom(
|
||||
@ -906,7 +918,7 @@ def run(rank, world_size, args):
|
||||
tokenizer=tokenizer,
|
||||
optimizer_g=optimizer_g,
|
||||
optimizer_d=optimizer_d,
|
||||
speaker_map=speaker_map,
|
||||
train_speaker_map=train_speaker_map,
|
||||
params=params,
|
||||
)
|
||||
|
||||
@ -935,8 +947,9 @@ def run(rank, world_size, args):
|
||||
scheduler_g=scheduler_g,
|
||||
scheduler_d=scheduler_d,
|
||||
train_dl=train_dl,
|
||||
valid_dl=valid_dl,
|
||||
speaker_map=speaker_map,
|
||||
dev_dl=dev_dl,
|
||||
train_speaker_map=train_speaker_map,
|
||||
dev_speaker_map=dev_speaker_map,
|
||||
scaler=scaler,
|
||||
tb_writer=tb_writer,
|
||||
world_size=world_size,
|
||||
|
@ -38,6 +38,7 @@ from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples
|
||||
AudioSamples,
|
||||
OnTheFlyFeatures,
|
||||
)
|
||||
from lhotse.features.io import KaldiReader
|
||||
from lhotse.utils import fix_random_seed
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
@ -51,8 +52,10 @@ class _SeedWorkers:
|
||||
def __call__(self, worker_id: int):
|
||||
fix_random_seed(self.seed + worker_id)
|
||||
|
||||
|
||||
LIBRITTS_SAMPLING_RATE = 24000
|
||||
|
||||
|
||||
class LibrittsTtsDataModule:
|
||||
"""
|
||||
DataModule for tts experiments.
|
||||
@ -82,7 +85,13 @@ class LibrittsTtsDataModule:
|
||||
"effective batch sizes, sampling strategies, applied data "
|
||||
"augmentations, etc.",
|
||||
)
|
||||
|
||||
group.add_argument(
|
||||
"--full-libri",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="""When enabled, use the entire LibriTTS training set.
|
||||
Otherwise, use the 460h clean subset.""",
|
||||
)
|
||||
group.add_argument(
|
||||
"--manifest-dir",
|
||||
type=Path,
|
||||
@ -90,10 +99,10 @@ class LibrittsTtsDataModule:
|
||||
help="Path to directory with train/valid/test cuts.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--speakers",
|
||||
"--speaker-embeds",
|
||||
type=Path,
|
||||
default=Path("data/speakers.txt"),
|
||||
help="Path to speakers.txt file.",
|
||||
default=Path("exp/xvector_nnet_1a/"),
|
||||
help="Path to directory with speaker embeddings.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--max-duration",
|
||||
@ -141,7 +150,7 @@ class LibrittsTtsDataModule:
|
||||
group.add_argument(
|
||||
"--return-cuts",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
default=True,
|
||||
help="When enabled, each batch will have the "
|
||||
"field: batch['cut'] with the cuts that "
|
||||
"were used to construct it.",
|
||||
@ -175,7 +184,7 @@ class LibrittsTtsDataModule:
|
||||
"""
|
||||
logging.info("About to create train dataset")
|
||||
train = SpeechSynthesisDataset(
|
||||
return_text=False,
|
||||
return_text=True,
|
||||
return_tokens=True,
|
||||
return_spk_ids=True,
|
||||
feature_input_strategy=eval(self.args.input_strategy)(),
|
||||
@ -191,7 +200,7 @@ class LibrittsTtsDataModule:
|
||||
use_fft_mag=True,
|
||||
)
|
||||
train = SpeechSynthesisDataset(
|
||||
return_text=False,
|
||||
return_text=True,
|
||||
return_tokens=True,
|
||||
return_spk_ids=True,
|
||||
feature_input_strategy=OnTheFlyFeatures(Spectrogram(config)),
|
||||
@ -238,7 +247,7 @@ class LibrittsTtsDataModule:
|
||||
|
||||
return train_dl
|
||||
|
||||
def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
|
||||
def dev_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
|
||||
logging.info("About to create dev dataset")
|
||||
if self.args.on_the_fly_feats:
|
||||
sampling_rate = LIBRITTS_SAMPLING_RATE
|
||||
@ -249,7 +258,7 @@ class LibrittsTtsDataModule:
|
||||
use_fft_mag=True,
|
||||
)
|
||||
validate = SpeechSynthesisDataset(
|
||||
return_text=False,
|
||||
return_text=True,
|
||||
return_tokens=True,
|
||||
return_spk_ids=True,
|
||||
feature_input_strategy=OnTheFlyFeatures(Spectrogram(config)),
|
||||
@ -257,7 +266,7 @@ class LibrittsTtsDataModule:
|
||||
)
|
||||
else:
|
||||
validate = SpeechSynthesisDataset(
|
||||
return_text=False,
|
||||
return_text=True,
|
||||
return_tokens=True,
|
||||
return_spk_ids=True,
|
||||
feature_input_strategy=eval(self.args.input_strategy)(),
|
||||
@ -290,7 +299,7 @@ class LibrittsTtsDataModule:
|
||||
use_fft_mag=True,
|
||||
)
|
||||
test = SpeechSynthesisDataset(
|
||||
return_text=False,
|
||||
return_text=True,
|
||||
return_tokens=True,
|
||||
return_spk_ids=True,
|
||||
feature_input_strategy=OnTheFlyFeatures(Spectrogram(config)),
|
||||
@ -298,7 +307,7 @@ class LibrittsTtsDataModule:
|
||||
)
|
||||
else:
|
||||
test = SpeechSynthesisDataset(
|
||||
return_text=False,
|
||||
return_text=True,
|
||||
return_tokens=True,
|
||||
return_spk_ids=True,
|
||||
feature_input_strategy=eval(self.args.input_strategy)(),
|
||||
@ -319,23 +328,106 @@ class LibrittsTtsDataModule:
|
||||
return test_dl
|
||||
|
||||
@lru_cache()
|
||||
def train_cuts(self) -> CutSet:
|
||||
logging.info("About to get train cuts")
|
||||
return load_manifest_lazy(self.args.manifest_dir / "vctk_cuts_train.jsonl.gz")
|
||||
def train_all_shuf_cuts(self) -> CutSet:
    """Return the combined, pre-shuffled LibriTTS training cuts.

    Lazily loads the manifest merging train-clean-100, train-clean-360
    and train-other-500 (tokens already attached by the prepare stage).
    """
    logging.info(
        "About to get the shuffled train-clean-100, \
train-clean-360 and train-other-500 cuts"
    )
    return load_manifest_lazy(
        self.args.manifest_dir / "libritts_cuts_with_tokens_train-all-shuf.jsonl.gz"
    )
|
||||
|
||||
@lru_cache()
|
||||
def valid_cuts(self) -> CutSet:
|
||||
logging.info("About to get validation cuts")
|
||||
return load_manifest_lazy(self.args.manifest_dir / "vctk_cuts_valid.jsonl.gz")
|
||||
def train_clean_460_cuts(self) -> CutSet:
    """Lazily load the 460h clean LibriTTS training cuts (tokens attached)."""
    logging.info(
        "About to get the shuffled train-clean-100 and train-clean-360 cuts"
    )
    manifest = (
        self.args.manifest_dir
        / "libritts_cuts_with_tokens_train-clean-460.jsonl.gz"
    )
    return load_manifest_lazy(manifest)
|
||||
|
||||
@lru_cache()
|
||||
def test_cuts(self) -> CutSet:
|
||||
logging.info("About to get test cuts")
|
||||
return load_manifest_lazy(self.args.manifest_dir / "vctk_cuts_test.jsonl.gz")
|
||||
def dev_clean_cuts(self) -> CutSet:
    """Lazily load the LibriTTS dev-clean cuts (tokens attached)."""
    logging.info("About to get dev-clean cuts")
    manifest = self.args.manifest_dir / "libritts_cuts_with_tokens_dev-clean.jsonl.gz"
    return load_manifest_lazy(manifest)
|
||||
|
||||
@lru_cache()
|
||||
def speakers(self) -> Dict[str, int]:
|
||||
logging.info("About to get speakers")
|
||||
with open(self.args.speakers) as f:
|
||||
speakers = {line.strip(): i for i, line in enumerate(f)}
|
||||
return speakers
|
||||
def dev_other_cuts(self) -> CutSet:
    """Lazily load the LibriTTS dev-other cuts (tokens attached)."""
    logging.info("About to get dev-other cuts")
    manifest = self.args.manifest_dir / "libritts_cuts_with_tokens_dev-other.jsonl.gz"
    return load_manifest_lazy(manifest)
|
||||
|
||||
@lru_cache()
def test_clean_cuts(self) -> CutSet:
    """Lazily load the LibriTTS test-clean cuts (tokens attached)."""
    logging.info("About to get test-clean cuts")
    manifest = self.args.manifest_dir / "libritts_cuts_with_tokens_test-clean.jsonl.gz"
    return load_manifest_lazy(manifest)
|
||||
|
||||
@lru_cache()
def test_other_cuts(self) -> CutSet:
    """Lazily load the LibriTTS test-other cuts (tokens attached)."""
    logging.info("About to get test-other cuts")
    manifest = self.args.manifest_dir / "libritts_cuts_with_tokens_test-other.jsonl.gz"
    return load_manifest_lazy(manifest)
|
||||
|
||||
@lru_cache()
def train_all_shuf_xvector(self) -> KaldiReader:
    # Placeholder: x-vectors for the full shuffled training set have not
    # been extracted yet. train_clean_460_xvector shows the expected
    # feats.scp layout once they are.
    raise NotImplementedError(
        "Please implement the method to load speaker embeddings."
    )
|
||||
|
||||
@lru_cache()
def train_clean_460_xvector(self) -> KaldiReader:
    """Open a reader over the precomputed x-vectors for train-clean-460."""
    logging.info("About to get speaker embeddings for train-clean-460")
    scp = self.args.speaker_embeds / "xvectors_train_clean_460" / "feats.scp"
    return KaldiReader(str(scp))
|
||||
|
||||
@lru_cache()
def train_clean_100_xvector(self) -> KaldiReader:
    # Placeholder: x-vectors for train-clean-100 alone are not extracted;
    # only the merged 460h set is supported (train_clean_460_xvector).
    raise NotImplementedError(
        "Please implement the method to load speaker embeddings."
    )
|
||||
|
||||
@lru_cache()
def train_clean_360_xvector(self) -> KaldiReader:
    # Placeholder: x-vectors for train-clean-360 alone are not extracted;
    # only the merged 460h set is supported (train_clean_460_xvector).
    raise NotImplementedError(
        "Please implement the method to load speaker embeddings."
    )
|
||||
|
||||
@lru_cache()
def train_other_500_xvector(self) -> KaldiReader:
    # Placeholder: x-vectors for train-other-500 have not been extracted yet.
    raise NotImplementedError(
        "Please implement the method to load speaker embeddings."
    )
|
||||
|
||||
@lru_cache()
def dev_clean_xvector(self) -> KaldiReader:
    """Open a reader over the precomputed x-vectors for dev-clean."""
    logging.info("About to get speaker embeddings for dev-clean")
    scp = self.args.speaker_embeds / "xvectors_dev_clean" / "feats.scp"
    return KaldiReader(str(scp))
|
||||
|
||||
@lru_cache()
def dev_other_xvector(self) -> KaldiReader:
    # Placeholder: x-vectors for dev-other have not been extracted yet;
    # see dev_clean_xvector for the expected feats.scp layout.
    raise NotImplementedError(
        "Please implement the method to load speaker embeddings."
    )
|
||||
|
||||
@lru_cache()
def test_clean_xvector(self) -> KaldiReader:
    """Open a reader over the precomputed x-vectors for test-clean."""
    logging.info("About to get speaker embeddings for test-clean")
    scp = self.args.speaker_embeds / "xvectors_test_clean" / "feats.scp"
    return KaldiReader(str(scp))
|
||||
|
||||
@lru_cache()
def test_other_xvector(self) -> KaldiReader:
    # Placeholder: x-vectors for test-other have not been extracted yet;
    # see test_clean_xvector for the expected feats.scp layout.
    raise NotImplementedError(
        "Please implement the method to load speaker embeddings."
    )
|
||||
|
Loading…
x
Reference in New Issue
Block a user