Valle Recipe for WenetSpeech4TTS, LibriTTS, LibriTTS-R (#1805)

* add valle

* update readme
Yuekai Zhang 2024-11-22 11:18:01 +08:00 committed by GitHub
parent 57451b0382
commit cbe012d54c
16 changed files with 4675 additions and 15 deletions


@@ -49,3 +49,54 @@ To inference, use:
--epoch 400 \
--tokens data/tokens.txt
```
# [VALL-E](https://arxiv.org/abs/2301.02111)
The ./valle directory contains the code for training the VALL-E TTS model.
Checkpoints and training logs can be found [here](https://huggingface.co/yuekai/vall-e_libritts). A demo of the model trained on LibriTTS and [LibriTTS-R](https://www.openslr.org/141/) is available [here](https://huggingface.co/spaces/yuekai/valle-libritts-demo).
Preparation:
```
bash prepare.sh --start-stage 4
```
The training command is given below:
```
world_size=8
exp_dir=exp/valle
## Train AR model
python3 valle/train.py --max-duration 320 --filter-min-duration 0.5 --filter-max-duration 14 --train-stage 1 \
--num-buckets 6 --dtype "bfloat16" --save-every-n 1000 --valid-interval 2000 \
--share-embedding true --norm-first true --add-prenet false \
--decoder-dim 1024 --nhead 16 --num-decoder-layers 12 --prefix-mode 1 \
--base-lr 0.03 --warmup-steps 200 --average-period 0 \
--num-epochs 20 --start-epoch 1 --start-batch 0 --accumulate-grad-steps 1 \
--exp-dir ${exp_dir} --world-size ${world_size}
## Train NAR model
# cd ${exp_dir}
# ln -s ${exp_dir}/best-valid-loss.pt epoch-99.pt # so that --start-epoch 100 below resumes from it (100 = 99 + 1)
# cd -
python3 valle/train.py --max-duration 160 --filter-min-duration 0.5 --filter-max-duration 14 --train-stage 2 \
--num-buckets 6 --dtype "float32" --save-every-n 1000 --valid-interval 2000 \
--share-embedding true --norm-first true --add-prenet false \
--decoder-dim 1024 --nhead 16 --num-decoder-layers 12 --prefix-mode 1 \
--base-lr 0.03 --warmup-steps 200 --average-period 0 \
--num-epochs 40 --start-epoch 100 --start-batch 0 --accumulate-grad-steps 2 \
--exp-dir ${exp_dir} --world-size ${world_size}
```
To run inference, use:
```
huggingface-cli login
huggingface-cli download --local-dir ${exp_dir} yuekai/vall-e_libritts
top_p=1.0
python3 valle/infer.py --output-dir demos_epoch_${epoch}_avg_${avg}_top_p_${top_p} \
--top-k -1 --temperature 1.0 \
--text ./libritts.txt \
--checkpoint ${exp_dir}/epoch-${epoch}-avg-${avg}.pt --top-p ${top_p}
```
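The synthesized utterances are written to the directory given by --output-dir as 24 kHz WAV files (valle/infer.py saves them with torchaudio at a 24000 Hz sample rate). A minimal sketch for sanity-checking one output file; the path below is only a placeholder:
```
import torchaudio

# Placeholder path: substitute your --output-dir and the name of one
# generated file.
wav, sample_rate = torchaudio.load("demos/0.wav")
print(wav.shape, sample_rate)  # sample_rate is expected to be 24000
```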


@@ -0,0 +1 @@
../../../wenetspeech4tts/TTS/local/compute_neural_codec_and_prepare_text_tokens.py


@@ -132,3 +132,30 @@ if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
./local/prepare_token_file.py --tokens data/tokens.txt
fi
fi
audio_feats_dir=data/tokenized
dataset_parts="--dataset-parts all" # debug "-p dev-clean -p test-clean"
if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
log "Stage 5: Tokenize/Fbank LibriTTS for valle"
mkdir -p ${audio_feats_dir}
if [ ! -e ${audio_feats_dir}/.libritts.tokenize.done ]; then
python3 ./local/compute_neural_codec_and_prepare_text_tokens.py --dataset-parts "${dataset_parts}" \
--audio-extractor "Encodec" \
--batch-duration 400 \
--src-dir "data/manifests" \
--output-dir "${audio_feats_dir}"
fi
touch ${audio_feats_dir}/.libritts.tokenize.done
lhotse combine \
${audio_feats_dir}/libritts_cuts_train-clean-100.jsonl.gz \
${audio_feats_dir}/libritts_cuts_train-clean-360.jsonl.gz \
${audio_feats_dir}/libritts_cuts_train-other-500.jsonl.gz \
${audio_feats_dir}/cuts_train.jsonl.gz
lhotse copy \
${audio_feats_dir}/libritts_cuts_dev-clean.jsonl.gz \
${audio_feats_dir}/cuts_dev.jsonl.gz
lhotse copy \
${audio_feats_dir}/libritts_cuts_test-clean.jsonl.gz \
${audio_feats_dir}/cuts_test.jsonl.gz
fi
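# The combined cuts_{train,dev,test}.jsonl.gz files produced above are what
# valle/tts_datamodule.py later loads via its train_cuts() / dev_cuts() / test_cuts() methods.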

egs/libritts/TTS/valle Symbolic link

@@ -0,0 +1 @@
../../wenetspeech4tts/TTS/valle/


@@ -0,0 +1,72 @@
# Introduction
[**WenetSpeech4TTS**](https://huggingface.co/datasets/Wenetspeech4TTS/WenetSpeech4TTS) is a multi-domain **Mandarin** corpus derived from the open-sourced [WenetSpeech](https://arxiv.org/abs/2110.03370) dataset.
> [!CAUTION]
> The next-gen Kaldi framework provides tools and models for generating high-quality synthetic speech (Text-to-Speech, TTS).
> While these recipes have the potential to advance various fields such as accessibility, language education, and AI-driven solutions, they also carry certain ethical and legal responsibilities.
>
> By using this framework, you agree to the following:
> 1. Legal and Ethical Use: You shall not use this framework, or any models derived from it, for any unlawful or unethical purposes. This includes, but is not limited to: Creating voice clones without the explicit, informed consent of the individual whose voice is being cloned. Engaging in any form of identity theft, impersonation, or fraud using cloned voices. Violating any local, national, or international laws regarding privacy, intellectual property, or personal data.
>
> 2. Responsibility of Use: The users of this framework are solely responsible for ensuring that their use of voice cloning technologies complies with all applicable laws and ethical guidelines. We explicitly disclaim any liability for misuse of the technology.
>
> 3. Attribution and Use of Open-Source Components: This project is provided under the Apache 2.0 license. Users must adhere to the terms of this license and provide appropriate attribution when required.
>
> 4. No Warranty: This framework is provided “as-is,” without warranty of any kind, either express or implied. We do not guarantee that the use of this software will comply with legal requirements or that it will not infringe the rights of third parties.
# [VALL-E](https://arxiv.org/abs/2301.02111)
The ./valle directory contains the code for training the VALL-E TTS model.
Checkpoints and training logs can be found [here](https://huggingface.co/yuekai/vall-e_wenetspeech4tts). A demo of the model trained on the WenetSpeech4TTS Premium subset (945 hours) is available [here](https://huggingface.co/spaces/yuekai/valle_wenetspeech4tts_demo).
Preparation:
```
bash prepare.sh
```
The training command is given below:
```
world_size=8
exp_dir=exp/valle
## Train AR model
python3 valle/train.py --max-duration 320 --filter-min-duration 0.5 --filter-max-duration 14 --train-stage 1 \
--num-buckets 6 --dtype "bfloat16" --save-every-n 1000 --valid-interval 2000 \
--share-embedding true --norm-first true --add-prenet false \
--decoder-dim 1024 --nhead 16 --num-decoder-layers 12 --prefix-mode 1 \
--base-lr 0.03 --warmup-steps 200 --average-period 0 \
--num-epochs 20 --start-epoch 1 --start-batch 0 --accumulate-grad-steps 1 \
--exp-dir ${exp_dir} --world-size ${world_size}
## Train NAR model
# cd ${exp_dir}
# ln -s ${exp_dir}/best-valid-loss.pt epoch-99.pt # so that --start-epoch 100 below resumes from it (100 = 99 + 1)
# cd -
python3 valle/train.py --max-duration 160 --filter-min-duration 0.5 --filter-max-duration 14 --train-stage 2 \
--num-buckets 6 --dtype "float32" --save-every-n 1000 --valid-interval 2000 \
--share-embedding true --norm-first true --add-prenet false \
--decoder-dim 1024 --nhead 16 --num-decoder-layers 12 --prefix-mode 1 \
--base-lr 0.03 --warmup-steps 200 --average-period 0 \
--num-epochs 40 --start-epoch 100 --start-batch 0 --accumulate-grad-steps 2 \
--exp-dir ${exp_dir} --world-size ${world_size}
```
To run inference, use:
```
huggingface-cli login
huggingface-cli download --local-dir ${exp_dir} yuekai/vall-e_wenetspeech4tts
top_p=1.0
python3 valle/infer.py --output-dir demos_epoch_${epoch}_avg_${avg}_top_p_${top_p} \
--top-k -1 --temperature 1.0 \
--text ./aishell3.txt \
--checkpoint ${exp_dir}/epoch-${epoch}-avg-${avg}.pt \
--text-extractor pypinyin_initials_finals --top-p ${top_p}
```
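The file passed to --text (here aishell3.txt) is read by valle/infer.py as one test case per line, with four whitespace-separated fields: prompt text, prompt audio path, target text, and output wav name. A minimal sketch that writes such a file; the Chinese sentences and paths are made-up placeholders:
```
# Each field must itself contain no spaces, since infer.py splits every
# line on spaces and expects exactly four non-empty fields.
lines = [
    "欢迎使用下一代Kaldi prompts/speaker001_prompt.wav 今天天气真不错 speaker001_demo.wav",
]
with open("aishell3.txt", "w", encoding="utf-8") as f:
    f.write("\n".join(lines) + "\n")
```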
# Credits
- [vall-e](https://github.com/lifeiteng/vall-e)


@@ -0,0 +1,609 @@
#!/usr/bin/env python3
# Copyright 2023 (authors: Feiteng Li)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Phonemize Text and EnCodec Audio.
Usage example:
python3 ./local/compute_neural_codec_and_prepare_text_tokens.py --dataset-parts "${dataset_parts}" \
--text-extractor ${text_extractor} \
--audio-extractor ${audio_extractor} \
--batch-duration 2500 --prefix "wenetspeech4tts" \
--src-dir "data/manifests" --split 100 \
--output-dir "${audio_feats_dir}/wenetspeech4tts_${dataset_parts}_split_100"
"""
import argparse
import logging
import os
from dataclasses import asdict, dataclass
from pathlib import Path
from typing import Any, Dict, List, Optional, Union
import torch
import torch.multiprocessing
from encodec import EncodecModel
from encodec.utils import convert_audio
from lhotse import CutSet, NumpyHdf5Writer
from lhotse.features import FeatureExtractor
from lhotse.recipes.utils import read_manifests_if_cached
from lhotse.utils import Seconds, compute_num_frames
from phonemizer.backend import EspeakBackend
from phonemizer.backend.espeak.language_switch import LanguageSwitch
from phonemizer.backend.espeak.words_mismatch import WordMismatch
from phonemizer.punctuation import Punctuation
from phonemizer.separator import Separator
from tqdm.auto import tqdm
from icefall.utils import get_executor
try:
from pypinyin import Style, pinyin
from pypinyin.style._utils import get_finals, get_initials
except Exception:
pass
import re
from typing import Pattern
import numpy as np
from k2 import SymbolTable
# from valle.data import (
# AudioTokenConfig,
# AudioTokenExtractor,
# TextTokenizer,
# tokenize_text,
# )
# from valle.data.fbank import get_fbank_extractor
# from valle.utils import SymbolTable
os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"
# Torch's multithreaded behavior needs to be disabled or
# it wastes a lot of CPU and slow things down.
# Do this outside of main() in case it needs to take effect
# even when we are not invoking the main (e.g. when spawning subprocesses).
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
torch.multiprocessing.set_sharing_strategy("file_system")
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--src-dir",
type=Path,
default=Path("data/manifests"),
help="Path to the manifest files",
)
parser.add_argument(
"--output-dir",
type=Path,
default=Path("data/tokenized"),
help="Path to the tokenized files",
)
parser.add_argument(
"--text-extractor",
type=str,
default="espeak",
help="espeak or pypinyin or pypinyin_initials_finals",
)
parser.add_argument(
"--audio-extractor",
type=str,
default="Encodec",
help="Encodec or Fbank",
)
parser.add_argument(
"--dataset-parts",
type=str,
default="dev-clean test-clean",
help="Space separated dataset parts",
)
parser.add_argument(
"--prefix",
type=str,
default="libritts",
help="prefix of the manifest file",
)
parser.add_argument(
"--suffix",
type=str,
default="jsonl.gz",
help="suffix of the manifest file",
)
parser.add_argument(
"--batch-duration",
type=float,
default=400.0,
help="The maximum number of audio seconds in a batch. "
"Determines batch size dynamically.",
)
parser.add_argument(
"--split",
type=int,
default=1,
help="Split the cut_set into multiple parts",
)
return parser.parse_args()
class PypinyinBackend:
"""PypinyinBackend for Chinese. Most of the code is adapted from espnet.
There are two modes, "pypinyin" and "pypinyin_initials_finals": the former
produces whole syllables such as "ni1 hao3", while the latter splits them
into initials and finals such as "n i1 h ao3".
"""
def __init__(
self,
backend="initials_finals",
punctuation_marks: Union[str, Pattern] = Punctuation.default_marks(),
) -> None:
self.backend = backend
self.punctuation_marks = punctuation_marks
def phonemize(
self, text: List[str], separator: Separator, strip=True, njobs=1
) -> List[str]:
assert isinstance(text, List)
phonemized = []
for _text in text:
_text = re.sub(" +", " ", _text.strip())
_text = _text.replace(" ", separator.word)
phones = []
if self.backend == "pypinyin":
for n, py in enumerate(
pinyin(_text, style=Style.TONE3, neutral_tone_with_five=True)
):
if all([c in self.punctuation_marks for c in py[0]]):
if len(phones):
assert phones[-1] == separator.syllable
phones.pop(-1)
phones.extend(list(py[0]))
else:
phones.extend([py[0], separator.syllable])
elif self.backend == "pypinyin_initials_finals":
for n, py in enumerate(
pinyin(_text, style=Style.TONE3, neutral_tone_with_five=True)
):
if all([c in self.punctuation_marks for c in py[0]]):
if len(phones):
assert phones[-1] == separator.syllable
phones.pop(-1)
phones.extend(list(py[0]))
else:
if py[0][-1].isalnum():
initial = get_initials(py[0], strict=False)
if py[0][-1].isdigit():
final = get_finals(py[0][:-1], strict=False) + py[0][-1]
else:
final = get_finals(py[0], strict=False)
phones.extend(
[
initial,
separator.phone,
final,
separator.syllable,
]
)
else:
raise ValueError(f"Unexpected pinyin token: {py[0]}")
else:
raise NotImplementedError
phonemized.append(
"".join(phones).rstrip(f"{separator.word}{separator.syllable}")
)
return phonemized
class TextTokenizer:
"""Phonemize Text."""
def __init__(
self,
language="en-us",
backend="espeak",
separator=Separator(word="_", syllable="-", phone="|"),
preserve_punctuation=True,
punctuation_marks: Union[str, Pattern] = Punctuation.default_marks(),
with_stress: bool = False,
tie: Union[bool, str] = False,
language_switch: LanguageSwitch = "keep-flags",
words_mismatch: WordMismatch = "ignore",
) -> None:
if backend == "espeak":
phonemizer = EspeakBackend(
language,
punctuation_marks=punctuation_marks,
preserve_punctuation=preserve_punctuation,
with_stress=with_stress,
tie=tie,
language_switch=language_switch,
words_mismatch=words_mismatch,
)
elif backend in ["pypinyin", "pypinyin_initials_finals"]:
phonemizer = PypinyinBackend(
backend=backend,
punctuation_marks=punctuation_marks + separator.word,
)
else:
raise NotImplementedError(f"{backend}")
self.backend = phonemizer
self.separator = separator
def to_list(self, phonemized: str) -> List[str]:
fields = []
for word in phonemized.split(self.separator.word):
# "ɐ m|iː|n?" ɹ|ɪ|z|ɜː|v; h|ɪ|z.
pp = re.findall(r"\w+|[^\w\s]", word, re.UNICODE)
fields.extend(
[p for p in pp if p != self.separator.phone] + [self.separator.word]
)
assert len("".join(fields[:-1])) == len(phonemized) - phonemized.count(
self.separator.phone
)
return fields[:-1]
def __call__(self, text, strip=True) -> List[List[str]]:
if isinstance(text, str):
text = [text]
phonemized = self.backend.phonemize(
text, separator=self.separator, strip=strip, njobs=1
)
return [self.to_list(p) for p in phonemized]
def tokenize_text(tokenizer: TextTokenizer, text: str) -> List[str]:
phonemes = tokenizer([text.strip()])
return phonemes[0] # k2symbols
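# Example: with the "pypinyin_initials_finals" backend, a short Mandarin string
# such as "你好" should come out roughly as ['n', 'i3', '-', 'h', 'ao3']
# (initials and finals carrying tone digits, with '-' marking syllable boundaries):
#   tokenizer = TextTokenizer(backend="pypinyin_initials_finals")
#   tokens = tokenize_text(tokenizer, "你好")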
def remove_encodec_weight_norm(model):
from encodec.modules import SConv1d
from encodec.modules.seanet import SConvTranspose1d, SEANetResnetBlock
from torch.nn.utils import remove_weight_norm
encoder = model.encoder.model
for key in encoder._modules:
if isinstance(encoder._modules[key], SEANetResnetBlock):
remove_weight_norm(encoder._modules[key].shortcut.conv.conv)
block_modules = encoder._modules[key].block._modules
for skey in block_modules:
if isinstance(block_modules[skey], SConv1d):
remove_weight_norm(block_modules[skey].conv.conv)
elif isinstance(encoder._modules[key], SConv1d):
remove_weight_norm(encoder._modules[key].conv.conv)
decoder = model.decoder.model
for key in decoder._modules:
if isinstance(decoder._modules[key], SEANetResnetBlock):
remove_weight_norm(decoder._modules[key].shortcut.conv.conv)
block_modules = decoder._modules[key].block._modules
for skey in block_modules:
if isinstance(block_modules[skey], SConv1d):
remove_weight_norm(block_modules[skey].conv.conv)
elif isinstance(decoder._modules[key], SConvTranspose1d):
remove_weight_norm(decoder._modules[key].convtr.convtr)
elif isinstance(decoder._modules[key], SConv1d):
remove_weight_norm(decoder._modules[key].conv.conv)
class AudioTokenizer:
"""EnCodec audio."""
def __init__(
self,
device: Any = None,
) -> None:
# Instantiate a pretrained EnCodec model
model = EncodecModel.encodec_model_24khz()
model.set_target_bandwidth(6.0)
remove_encodec_weight_norm(model)
if not device:
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda:0")
self._device = device
self.codec = model.to(device)
self.sample_rate = model.sample_rate
self.channels = model.channels
@property
def device(self):
return self._device
def encode(self, wav: torch.Tensor) -> torch.Tensor:
return self.codec.encode(wav.to(self.device))
def decode(self, frames: torch.Tensor) -> torch.Tensor:
return self.codec.decode(frames)
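# With the 24 kHz EnCodec model at a 6 kbps target bandwidth, encode() should
# return a single (codes, scale) frame whose codes tensor has shape [B, n_q, T],
# where n_q is 8 codebooks and T advances at 75 frames per second
# (hop size 320 samples, matching AudioTokenConfig.frame_shift below).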
@dataclass
class AudioTokenConfig:
frame_shift: Seconds = 320.0 / 24000
num_quantizers: int = 8
def to_dict(self) -> Dict[str, Any]:
return asdict(self)
@staticmethod
def from_dict(data: Dict[str, Any]) -> "AudioTokenConfig":
return AudioTokenConfig(**data)
class AudioTokenExtractor(FeatureExtractor):
name = "encodec"
config_type = AudioTokenConfig
def __init__(self, config: Optional[Any] = None):
super(AudioTokenExtractor, self).__init__(config)
self.tokenizer = AudioTokenizer()
def extract(
self, samples: Union[np.ndarray, torch.Tensor], sampling_rate: int
) -> np.ndarray:
if not isinstance(samples, torch.Tensor):
samples = torch.from_numpy(samples)
if sampling_rate != self.tokenizer.sample_rate:
samples = convert_audio(
samples,
sampling_rate,
self.tokenizer.sample_rate,
self.tokenizer.channels,
)
if len(samples.shape) == 2:
samples = samples.unsqueeze(0)
else:
raise ValueError()
device = self.tokenizer.device
encoded_frames = self.tokenizer.encode(samples.detach().to(device))
codes = encoded_frames[0][0] # [B, n_q, T]
if True:
duration = round(samples.shape[-1] / sampling_rate, ndigits=12)
expected_num_frames = compute_num_frames(
duration=duration,
frame_shift=self.frame_shift,
sampling_rate=sampling_rate,
)
assert abs(codes.shape[-1] - expected_num_frames) <= 1
codes = codes[..., :expected_num_frames]
return codes.cpu().squeeze(0).permute(1, 0).numpy()
@property
def frame_shift(self) -> Seconds:
return self.config.frame_shift
def feature_dim(self, sampling_rate: int) -> int:
return self.config.num_quantizers
def pad_tensor_list(self, tensor_list, device, padding_value=0):
lengths = [tensor.shape[0] for tensor in tensor_list]
tensor_list = [torch.Tensor(t).to(device) for t in tensor_list]
padded_tensor = torch.nn.utils.rnn.pad_sequence(
tensor_list, batch_first=True, padding_value=padding_value
)
return padded_tensor, lengths
def extract_batch(self, samples, sampling_rate, lengths) -> np.ndarray:
samples = [wav.squeeze() for wav in samples]
device = self.tokenizer.device
samples, lengths = self.pad_tensor_list(samples, device)
samples = samples.unsqueeze(1)
if not isinstance(samples, torch.Tensor):
samples = torch.from_numpy(samples)
if len(samples.shape) != 3:
raise ValueError()
if sampling_rate != self.tokenizer.sample_rate:
samples = [
convert_audio(
wav,
sampling_rate,
self.tokenizer.sample_rate,
self.tokenizer.channels,
)
for wav in samples
]
samples = torch.stack(samples, 0) # convert samples from list to tensor
# Extract discrete codes from EnCodec
with torch.no_grad():
encoded_frames = self.tokenizer.encode(samples.detach().to(device))
encoded_frames = encoded_frames[0][0] # [B, n_q, T]
batch_codes = []
for b, length in enumerate(lengths):
codes = encoded_frames[b]
duration = round(length / sampling_rate, ndigits=12)
expected_num_frames = compute_num_frames(
duration=duration,
frame_shift=self.frame_shift,
sampling_rate=sampling_rate,
)
batch_codes.append(codes[..., :expected_num_frames])
return [codes.cpu().permute(1, 0).numpy() for codes in batch_codes]
def main():
args = get_args()
dataset_parts = args.dataset_parts.replace("--dataset-parts", "").strip()
if dataset_parts == "all": # LibriTTS
dataset_parts = [
"dev-clean",
"dev-other",
"test-clean",
"test-other",
"train-clean-100",
"train-clean-360",
"train-other-500",
]
else:
dataset_parts = dataset_parts.replace("-p", "").strip().split(" ")
assert len(dataset_parts) >= 1
manifests = read_manifests_if_cached(
dataset_parts=dataset_parts,
output_dir=args.src_dir,
prefix=args.prefix,
suffix=args.suffix,
types=["recordings", "supervisions", "cuts"],
)
text_tokenizer = None
if args.text_extractor:
text_tokenizer = TextTokenizer(backend=args.text_extractor)
audio_extractor = None
if args.audio_extractor:
if args.audio_extractor == "Encodec":
audio_extractor = AudioTokenExtractor(AudioTokenConfig())
else:
raise NotImplementedError(f"{args.audio_extractor}")
Path(args.output_dir).mkdir(parents=True, exist_ok=True)
unique_symbols = set()
num_jobs = min(32, os.cpu_count())
logging.info(f"dataset_parts: {dataset_parts} manifests {len(manifests)}")
prefix = args.prefix
if prefix and not prefix.endswith("_"):
prefix = f"{prefix}_"
with get_executor() as ex:
for partition, m in manifests.items():
logging.info(
f"Processing partition: {partition} CUDA: {torch.cuda.is_available()}"
)
try:
cut_set = CutSet.from_manifests(
recordings=m["recordings"],
supervisions=m["supervisions"],
)
except Exception:
cut_set = m["cuts"]
# Split cut_set if split > 1
split = 1
if args.split > 1:
cut_sets = cut_set.split(args.split)
split = args.split
else:
cut_sets = [cut_set]
for idx, part in enumerate(cut_sets):
if args.audio_extractor:
if args.audio_extractor == "Encodec":
storage_path = f"{args.output_dir}/{args.prefix}_encodec_{partition}_{idx if split > 1 else ''}"
else:
storage_path = f"{args.output_dir}/{args.prefix}_fbank_{partition}_{idx if split > 1 else ''}"
if args.prefix.lower() in [
"ljspeech",
"aishell",
"baker",
"wenetspeech4tts",
]:
part = part.resample(24000)
assert args.prefix.lower() in [
"ljspeech",
"aishell",
"baker",
"wenetspeech4tts",
"libritts",
"libritts-r",
]
with torch.no_grad():
if (
torch.cuda.is_available()
and args.audio_extractor == "Encodec"
):
part = part.compute_and_store_features_batch(
extractor=audio_extractor,
storage_path=storage_path,
num_workers=num_jobs,
batch_duration=args.batch_duration,
collate=False,
overwrite=True,
storage_type=NumpyHdf5Writer,
)
else:
part = part.compute_and_store_features(
extractor=audio_extractor,
storage_path=storage_path,
num_jobs=num_jobs if ex is None else 64,
executor=ex,
storage_type=NumpyHdf5Writer,
)
# TextTokenizer
if args.text_extractor:
for c in tqdm(part):
if args.prefix == "ljspeech":
text = c.supervisions[0].custom["normalized_text"]
text = text.replace("“", '"').replace("”", '"')
phonemes = tokenize_text(text_tokenizer, text=text)
elif args.prefix in [
"aishell",
"aishell2",
"wenetspeech4tts",
"libritts",
"libritts-r",
]:
phonemes = tokenize_text(
text_tokenizer, text=c.supervisions[0].text
)
if c.supervisions[0].custom is None:
c.supervisions[0].custom = {}
c.supervisions[0].normalized_text = c.supervisions[0].text
else:
raise NotImplementedError(f"{args.prefix}")
unique_symbols.update(phonemes)
c.tokens = phonemes
assert c.supervisions[
0
].normalized_text, "normalized_text is None"
# Save each part with an index if split > 1
cuts_filename = (
f"{prefix}cuts_{partition}.{idx if split > 1 else ''}.{args.suffix}"
)
part.to_file(f"{args.output_dir}/{cuts_filename}")
logging.info(f"Saved {cuts_filename}")
if args.text_extractor:
unique_phonemes = SymbolTable()
for s in sorted(list(unique_symbols)):
unique_phonemes.add(s)
logging.info(f"{len(unique_symbols)} unique phonemes: {unique_symbols}")
unique_phonemes_file = f"{args.output_dir}/unique_text_tokens.k2symbols"
unique_phonemes.to_file(unique_phonemes_file)
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()


@@ -0,0 +1,53 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
# Copyright 2023 (authors: Feiteng Li)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This file displays duration statistics of utterances in the manifests.
You can use the displayed values to choose the minimum/maximum duration
for filtering out short and long utterances during training.
"""
import argparse
from pathlib import Path
from lhotse import load_manifest_lazy
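# The statistics printed below are meant to guide the --filter-min-duration /
# --filter-max-duration options of valle/train.py; the training commands in the
# READMEs of this recipe use 0.5 and 14 seconds.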
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--manifest-dir",
type=Path,
default=Path("data/tokenized"),
help="Path to the tokenized manifests.",
)
return parser.parse_args()
def main():
args = get_args()
manifest_dir = args.manifest_dir or Path("data/tokenized")
for part in ["train", "dev", "test"]:
print(f"## {part}")
cuts = load_manifest_lazy(manifest_dir / f"cuts_{part}.jsonl.gz")
cuts.describe()
print("\n")
if __name__ == "__main__":
main()


@@ -0,0 +1,100 @@
#!/usr/bin/env bash
set -eou pipefail
# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674
export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
stage=1
stop_stage=4
dl_dir=$PWD/download
dataset_parts="Premium" # Basic for all ~10k hours of data, Premium for about 10% of the data
text_extractor="pypinyin_initials_finals" # default is espeak for English
audio_extractor="Encodec" # or Fbank
audio_feats_dir=data/tokenized
. shared/parse_options.sh || exit 1
# All files generated by this script are saved in "data".
# You can safely remove "data" and rerun this script to regenerate it.
mkdir -p data
log() {
# This function is from espnet
local fname=${BASH_SOURCE[1]##*/}
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}
if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
log "dl_dir: $dl_dir"
log "Stage 0: Download data"
huggingface-cli login
huggingface-cli download --repo-type dataset --local-dir $dl_dir Wenetspeech4TTS/WenetSpeech4TTS
# Extract the downloaded data:
for folder in Standard Premium Basic; do
for file in "$dl_dir/$folder"/*.tar.gz; do
tar -xzvf "$file" -C "$dl_dir/$folder"
done
done
fi
if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
log "Stage 1: Prepare wenetspeech4tts manifest"
# We assume that you have downloaded the wenetspeech4tts corpus
# to $dl_dir/wenetspeech4tts
mkdir -p data/manifests
if [ ! -e data/manifests/.wenetspeech4tts.done ]; then
lhotse prepare wenetspeech4tts $dl_dir data/manifests --dataset-parts "${dataset_parts}"
touch data/manifests/.wenetspeech4tts.done
fi
fi
if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
log "Stage 2: Tokenize/Fbank wenetspeech4tts"
mkdir -p ${audio_feats_dir}
if [ ! -e ${audio_feats_dir}/.wenetspeech4tts.tokenize.done ]; then
python3 ./local/compute_neural_codec_and_prepare_text_tokens.py --dataset-parts "${dataset_parts}" \
--text-extractor ${text_extractor} \
--audio-extractor ${audio_extractor} \
--batch-duration 2500 --prefix "wenetspeech4tts" \
--src-dir "data/manifests" \
--split 100 \
--output-dir "${audio_feats_dir}/wenetspeech4tts_${dataset_parts}_split_100"
cp ${audio_feats_dir}/wenetspeech4tts_${dataset_parts}_split_100/unique_text_tokens.k2symbols ${audio_feats_dir}
fi
touch ${audio_feats_dir}/.wenetspeech4tts.tokenize.done
fi
if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
log "Stage 3: Combine features"
if [ ! -f ${audio_feats_dir}/wenetspeech4tts_cuts_${dataset_parts}.jsonl.gz ]; then
pieces=$(find ${audio_feats_dir}/wenetspeech4tts_${dataset_parts}_split_100 -name "*.jsonl.gz")
lhotse combine $pieces ${audio_feats_dir}/wenetspeech4tts_cuts_${dataset_parts}.jsonl.gz
fi
fi
if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
log "Stage 4: Prepare wenetspeech4tts train/dev/test"
if [ ! -e ${audio_feats_dir}/.wenetspeech4tts.train.done ]; then
lhotse subset --first 400 \
${audio_feats_dir}/wenetspeech4tts_cuts_${dataset_parts}.jsonl.gz \
${audio_feats_dir}/cuts_dev.jsonl.gz
lhotse subset --last 400 \
${audio_feats_dir}/wenetspeech4tts_cuts_${dataset_parts}.jsonl.gz \
${audio_feats_dir}/cuts_test.jsonl.gz
lhotse copy \
${audio_feats_dir}/wenetspeech4tts_cuts_${dataset_parts}.jsonl.gz \
${audio_feats_dir}/cuts_train.jsonl.gz
touch ${audio_feats_dir}/.wenetspeech4tts.train.done
fi
python3 ./local/display_manifest_statistics.py --manifest-dir ${audio_feats_dir}
fi


@@ -0,0 +1 @@
../../../icefall/shared/


@@ -0,0 +1 @@
../local/compute_neural_codec_and_prepare_text_tokens.py


@@ -0,0 +1,300 @@
#!/usr/bin/env python3
# Copyright 2023 (authors: Feiteng Li)
# Copyright 2024 (authors: Yuekai Zhang)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script is used to synthesize speech from text prompts and audio prompts.
Usage example:
python3 valle/infer.py --output-dir demos_epoch_${epoch}_avg_${avg} \
--checkpoint=${exp_dir}/epoch-${epoch}-avg-${avg}.pt \
--text-prompts "KNOT one point one five miles per hour." \
--audio-prompts ./prompts/8463_294825_000043_000000.wav \
--text "To get up and running quickly just follow the steps below."
top_p=1.0
python3 valle/infer.py --output-dir demos_epoch_${epoch}_avg_${avg}_top_p_${top_p} \
--top-k -1 --temperature 1.0 \
--text ./aishell3.txt \
--checkpoint ${exp_dir}/epoch-${epoch}-avg-${avg}.pt \
--text-extractor pypinyin_initials_finals --top-p ${top_p}
"""
import argparse
import logging
import os
from pathlib import Path
os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"
import torch
import torchaudio
from compute_neural_codec_and_prepare_text_tokens import (
AudioTokenizer,
TextTokenizer,
tokenize_text,
)
from encodec.utils import convert_audio
from k2 import symbol_table
from tokenizer import get_text_token_collater
from valle import VALLE
from icefall.utils import AttributeDict, str2bool
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--text-prompts",
type=str,
default="",
help="Text prompts which are separated by |.",
)
parser.add_argument(
"--audio-prompts",
type=str,
default="",
help="Audio prompts which are separated by | and should be aligned with --text-prompts.",
)
parser.add_argument(
"--text",
type=str,
default="",
help="prompt text\t prompt audio\ttarget text\ttarget audio",
)
parser.add_argument(
"--text-extractor",
type=str,
default="espeak",
help="espeak or pypinyin or pypinyin_initials_finals",
)
parser.add_argument(
"--checkpoint",
type=str,
default="exp/vallf_nano_full/checkpoint-100000.pt",
help="Path to the saved checkpoint.",
)
parser.add_argument(
"--output-dir",
type=Path,
default=Path("infer/demo"),
help="Path to the tokenized files.",
)
parser.add_argument(
"--top-k",
type=int,
default=-100,
help="Whether the AR decoder uses top-k sampling (enabled when > 0).",
)
parser.add_argument(
"--top-p",
type=float,
default=1.0,
help="Whether the AR decoder uses top-p (nucleus) sampling (enabled when > 0).",
)
parser.add_argument(
"--temperature",
type=float,
default=1.0,
help="The sampling temperature of the AR decoder.",
)
parser.add_argument(
"--continual",
type=str2bool,
default=False,
help="Do continual task.",
)
parser.add_argument(
"--repetition-aware-sampling",
type=str2bool,
default=False,
help="Whether the AR decoder uses VALL-E 2 repetition-aware sampling, see https://arxiv.org/pdf/2406.05370",
)
return parser.parse_args()
def load_model(checkpoint, device):
if not checkpoint:
return None
checkpoint = torch.load(checkpoint, map_location=device)
params = AttributeDict(checkpoint)
model = VALLE(
params.decoder_dim,
params.nhead,
params.num_decoder_layers,
norm_first=params.norm_first,
add_prenet=params.add_prenet,
prefix_mode=params.prefix_mode,
share_embedding=params.share_embedding,
nar_scale_factor=params.scale_factor,
prepend_bos=params.prepend_bos,
num_quantizers=params.num_quantizers,
)
missing_keys, unexpected_keys = model.load_state_dict(
checkpoint["model"], strict=True
)
assert not missing_keys
model.to(device)
model.eval()
return model, params.text_tokens
def tokenize_audio(tokenizer: AudioTokenizer, audio_path: str):
# Load and pre-process the audio waveform
wav, sr = torchaudio.load(audio_path)
wav = convert_audio(wav, sr, tokenizer.sample_rate, tokenizer.channels)
wav = wav.unsqueeze(0)
# Extract discrete codes from EnCodec
with torch.no_grad():
encoded_frames = tokenizer.encode(wav)
return encoded_frames
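# Note: encode() returns a list of (codes, scale) pairs; the callers below use
# encoded_frames[0][0], i.e. the code tensor of shape [B, n_q, T] produced by
# the single frame of the 24 kHz EnCodec model.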
@torch.no_grad()
def main():
args = get_args()
text_tokenizer = TextTokenizer(backend=args.text_extractor)
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
model, text_tokens = load_model(args.checkpoint, device)
text_collater = get_text_token_collater(text_tokens)
audio_tokenizer = AudioTokenizer()
Path(args.output_dir).mkdir(parents=True, exist_ok=True)
text_prompts = " ".join(args.text_prompts.split("|"))
audio_prompts = []
if args.audio_prompts:
for n, audio_file in enumerate(args.audio_prompts.split("|")):
encoded_frames = tokenize_audio(audio_tokenizer, audio_file)
if False:
samples = audio_tokenizer.decode(encoded_frames)
torchaudio.save(f"{args.output_dir}/p{n}.wav", samples[0], 24000)
audio_prompts.append(encoded_frames[0][0])
assert len(args.text_prompts.split("|")) == len(audio_prompts)
audio_prompts = torch.concat(audio_prompts, dim=-1).transpose(2, 1)
audio_prompts = audio_prompts.to(device)
if os.path.isfile(args.text): # for demos
# https://github.com/lifeiteng/lifeiteng.github.com/blob/main/valle/prepare.py
with open(args.text) as f:
for line in f:
fields = line.strip().split(" ")
fields = [item for item in fields if item]
assert len(fields) == 4
prompt_text, prompt_audio, text, audio_path = fields
logging.info(f"synthesize text: {text}")
text_tokens, text_tokens_lens = text_collater(
[
tokenize_text(
text_tokenizer, text=f"{prompt_text} {text}".strip()
)
]
)
_, enroll_x_lens = text_collater(
[tokenize_text(text_tokenizer, text=f"{prompt_text}".strip())]
)
audio_prompts = tokenize_audio(audio_tokenizer, prompt_audio)
audio_prompts = audio_prompts[0][0].transpose(2, 1).to(device)
# synthesis
encoded_frames = model.inference(
text_tokens.to(device),
text_tokens_lens.to(device),
audio_prompts,
enroll_x_lens=enroll_x_lens,
top_k=args.top_k,
temperature=args.temperature,
top_p=args.top_p,
ras=args.repetition_aware_sampling,
)
samples = audio_tokenizer.decode(
[(encoded_frames.transpose(2, 1), None)]
)
# store
# save audio path into args.output_dir + audio_path
audio_path = f"{args.output_dir}/{audio_path}"
# mkdir -p
os.makedirs(os.path.dirname(audio_path), exist_ok=True)
torchaudio.save(audio_path, samples[0].cpu(), 24000)
return
for n, text in enumerate(args.text.split("|")):
logging.info(f"synthesize text: {text}")
text_tokens, text_tokens_lens = text_collater(
[tokenize_text(text_tokenizer, text=f"{text_prompts} {text}".strip())]
)
# synthesis
if args.continual:
assert text == ""
encoded_frames = model.continual(
text_tokens.to(device),
text_tokens_lens.to(device),
audio_prompts,
)
else:
enroll_x_lens = None
if text_prompts:
_, enroll_x_lens = text_collater(
[tokenize_text(text_tokenizer, text=f"{text_prompts}".strip())]
)
encoded_frames = model.inference(
text_tokens.to(device),
text_tokens_lens.to(device),
audio_prompts,
enroll_x_lens=enroll_x_lens,
top_k=args.top_k,
temperature=args.temperature,
top_p=args.top_p,
ras=args.repetition_aware_sampling,
)
if audio_prompts != []:
samples = audio_tokenizer.decode([(encoded_frames.transpose(2, 1), None)])
# store
torchaudio.save(f"{args.output_dir}/{n}.wav", samples[0].cpu(), 24000)
else: # Transformer
pass
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()


@@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/optim.py


@@ -0,0 +1,111 @@
from pathlib import Path
from typing import List, Tuple
import numpy as np
import torch
from k2 import SymbolTable
class TextTokenCollater:
"""Collate list of text tokens
Map sentences to integers. Sentences are padded to equal length.
Beginning and end-of-sequence symbols can be added.
Example:
>>> token_collater = TextTokenCollater(text_tokens)
>>> tokens_batch, tokens_lens = token_collater(text)
Returns:
tokens_batch: IntTensor of shape (B, L)
B: batch dimension, number of input sentences
L: length of the longest sentence
tokens_lens: IntTensor of shape (B,)
Length of each sentence after adding <eos> and <bos>
but before padding.
"""
def __init__(
self,
text_tokens: List[str],
add_eos: bool = True,
add_bos: bool = True,
pad_symbol: str = "<pad>",
bos_symbol: str = "<bos>",
eos_symbol: str = "<eos>",
):
self.pad_symbol = pad_symbol
self.add_eos = add_eos
self.add_bos = add_bos
self.bos_symbol = bos_symbol
self.eos_symbol = eos_symbol
unique_tokens = (
[pad_symbol]
+ ([bos_symbol] if add_bos else [])
+ ([eos_symbol] if add_eos else [])
+ sorted(text_tokens)
)
self.token2idx = {token: idx for idx, token in enumerate(unique_tokens)}
self.idx2token = [token for token in unique_tokens]
def index(self, tokens_list: List[str]) -> Tuple[torch.Tensor, torch.Tensor]:
seqs, seq_lens = [], []
for tokens in tokens_list:
assert all(s in self.token2idx for s in tokens)
seq = (
([self.bos_symbol] if self.add_bos else [])
+ list(tokens)
+ ([self.eos_symbol] if self.add_eos else [])
)
seqs.append(seq)
seq_lens.append(len(seq))
max_len = max(seq_lens)
for k, (seq, seq_len) in enumerate(zip(seqs, seq_lens)):
seq.extend([self.pad_symbol] * (max_len - seq_len))
tokens = torch.from_numpy(
np.array(
[[self.token2idx[token] for token in seq] for seq in seqs],
dtype=np.int64,
)
)
tokens_lens = torch.IntTensor(seq_lens)
return tokens, tokens_lens
def __call__(self, texts: List[str]) -> Tuple[torch.Tensor, torch.Tensor]:
tokens_seqs = [[p for p in text] for text in texts]
max_len = len(max(tokens_seqs, key=len))
seqs = [
([self.bos_symbol] if self.add_bos else [])
+ list(seq)
+ ([self.eos_symbol] if self.add_eos else [])
+ [self.pad_symbol] * (max_len - len(seq))
for seq in tokens_seqs
]
tokens_batch = torch.from_numpy(
np.array(
[[self.token2idx[token] for token in seq] for seq in seqs],
dtype=np.int64,
)
)
tokens_lens = torch.IntTensor(
[len(seq) + int(self.add_eos) + int(self.add_bos) for seq in tokens_seqs]
)
return tokens_batch, tokens_lens
def get_text_token_collater(text_tokens_file: str) -> TextTokenCollater:
text_tokens_path = Path(text_tokens_file)
unique_tokens = SymbolTable.from_file(text_tokens_path)
collater = TextTokenCollater(unique_tokens.symbols, add_bos=True, add_eos=True)
return collater
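# Illustrative usage, assuming the listed symbols exist in the token table
# produced by local/compute_neural_codec_and_prepare_text_tokens.py:
#   collater = get_text_token_collater("data/tokenized/unique_text_tokens.k2symbols")
#   tokens, tokens_lens = collater([["n", "i3", "-", "h", "ao3"]])
#   # tokens: int64 tensor of shape (1, 7) = <bos> + 5 tokens + <eos>
#   # tokens_lens: IntTensor([7])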

File diff suppressed because it is too large


@@ -0,0 +1,343 @@
# Copyright 2021 Piotr Żelasko
# Copyright 2022-2024 Xiaomi Corporation (Authors: Mingshuang Luo,
# Zengwei Yao,
# Zengrui Jin,)
# Copyright 2023 (authors: Feiteng Li)
# Copyright 2024 (Author: Yuekai Zhang)
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
from functools import lru_cache
from pathlib import Path
from typing import Any, Dict, Optional
import torch
from lhotse import CutSet, Spectrogram, SpectrogramConfig, load_manifest_lazy
from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures
CutConcatenate,
DynamicBucketingSampler,
PrecomputedFeatures,
SimpleCutSampler,
SpeechSynthesisDataset,
)
from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples
AudioSamples,
OnTheFlyFeatures,
)
from lhotse.features.io import KaldiReader
from lhotse.utils import fix_random_seed
from torch.utils.data import DataLoader
from icefall.utils import str2bool
class _SeedWorkers:
def __init__(self, seed: int):
self.seed = seed
def __call__(self, worker_id: int):
fix_random_seed(self.seed + worker_id)
class TtsDataModule:
"""
DataModule for tts experiments.
It assumes there is always one train and valid dataloader,
but there can be multiple test dataloaders (e.g. LibriSpeech test-clean
and test-other).
It contains all the common data pipeline modules used in TTS
experiments, e.g.:
- dynamic batch size,
- bucketing samplers,
- cut concatenation,
- on-the-fly feature extraction
This class should be derived for specific corpora used in TTS tasks.
"""
def __init__(self, args: argparse.Namespace):
self.args = args
@classmethod
def add_arguments(cls, parser: argparse.ArgumentParser):
group = parser.add_argument_group(
title="TTS data related options",
description="These options are used for the preparation of "
"PyTorch DataLoaders from Lhotse CutSet's -- they control the "
"effective batch sizes, sampling strategies, applied data "
"augmentations, etc.",
)
group.add_argument(
"--manifest-dir",
type=Path,
default=Path("data/tokenized"),
help="Path to directory with train/valid/test cuts.",
)
group.add_argument(
"--speaker-embeds",
type=Path,
default=Path("exp/xvector_nnet_1a/"),
help="Path to directory with speaker embeddings.",
)
group.add_argument(
"--max-duration",
type=int,
default=200.0,
help="Maximum pooled recordings duration (seconds) in a "
"single batch. You can reduce it if it causes CUDA OOM.",
)
group.add_argument(
"--bucketing-sampler",
type=str2bool,
default=True,
help="When enabled, the batches will come from buckets of "
"similar duration (saves padding frames).",
)
group.add_argument(
"--num-buckets",
type=int,
default=30,
help="The number of buckets for the DynamicBucketingSampler "
"(you might want to increase it for larger datasets).",
)
group.add_argument(
"--on-the-fly-feats",
type=str2bool,
default=False,
help="When enabled, use on-the-fly cut mixing and feature "
"extraction. Will drop existing precomputed feature manifests "
"if available.",
)
group.add_argument(
"--shuffle",
type=str2bool,
default=True,
help="When enabled (=default), the examples will be "
"shuffled for each epoch.",
)
group.add_argument(
"--drop-last",
type=str2bool,
default=True,
help="Whether to drop last batch. Used by sampler.",
)
group.add_argument(
"--return-cuts",
type=str2bool,
default=True,
help="When enabled, each batch will have the "
"field: batch['cut'] with the cuts that "
"were used to construct it.",
)
group.add_argument(
"--num-workers",
type=int,
default=4,
help="The number of training dataloader workers that "
"collect the batches.",
)
group.add_argument(
"--enable-spec-aug",
type=str2bool,
default=False,
help="When enabled, use SpecAugment for training dataset.",
)
group.add_argument(
"--input-strategy",
type=str,
default="PrecomputedFeatures",
help="AudioSamples or PrecomputedFeatures",
)
group.add_argument(
"--dataset",
type=str,
default="libritts",
help="--input-strategy PromptedPrecomputedFeatures needs dataset name to prepare prompts.",
)
parser.add_argument(
"--sampling-rate",
type=int,
default=24000,
help="""Audio sampling rate.""",
)
def train_dataloaders(
self,
cuts_train: CutSet,
sampler_state_dict: Optional[Dict[str, Any]] = None,
) -> DataLoader:
"""
Args:
cuts_train:
CutSet for training.
sampler_state_dict:
The state dict for the training sampler.
"""
logging.info("About to create train dataset")
train = SpeechSynthesisDataset(
return_text=True,
return_tokens=True,
return_spk_ids=False,
feature_input_strategy=eval(self.args.input_strategy)(),
return_cuts=self.args.return_cuts,
)
if self.args.on_the_fly_feats:
raise NotImplementedError
if self.args.bucketing_sampler:
logging.info("Using DynamicBucketingSampler.")
train_sampler = DynamicBucketingSampler(
cuts_train,
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
num_buckets=self.args.num_buckets,
buffer_size=self.args.num_buckets * 2000,
shuffle_buffer_size=self.args.num_buckets * 5000,
drop_last=self.args.drop_last,
)
else:
logging.info("Using SimpleCutSampler.")
train_sampler = SimpleCutSampler(
cuts_train,
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
)
logging.info("About to create train dataloader")
if sampler_state_dict is not None:
logging.info("Loading sampler state dict")
train_sampler.load_state_dict(sampler_state_dict)
# 'seed' is derived from the current random state, which will have
# previously been set in the main process.
seed = torch.randint(0, 100000, ()).item()
worker_init_fn = _SeedWorkers(seed)
train_dl = DataLoader(
train,
sampler=train_sampler,
batch_size=None,
num_workers=self.args.num_workers,
persistent_workers=False,
worker_init_fn=worker_init_fn,
)
return train_dl
def dev_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
logging.info("About to create dev dataset")
if self.args.on_the_fly_feats:
raise NotImplementedError
else:
validate = SpeechSynthesisDataset(
return_text=True,
return_tokens=True,
return_spk_ids=False,
feature_input_strategy=eval(self.args.input_strategy)(),
return_cuts=self.args.return_cuts,
)
dev_sampler = DynamicBucketingSampler(
cuts_valid,
max_duration=self.args.max_duration,
shuffle=False,
)
logging.info("About to create valid dataloader")
dev_dl = DataLoader(
validate,
sampler=dev_sampler,
batch_size=None,
num_workers=self.args.num_workers,
persistent_workers=False,
)
return dev_dl
def test_dataloaders(self, cuts: CutSet) -> DataLoader:
logging.info("About to create test dataset")
if self.args.on_the_fly_feats:
raise NotImplementedError
else:
test = SpeechSynthesisDataset(
return_text=True,
return_tokens=True,
return_spk_ids=False,
feature_input_strategy=eval(self.args.input_strategy)(),
return_cuts=self.args.return_cuts,
)
test_sampler = DynamicBucketingSampler(
cuts,
max_duration=self.args.max_duration,
shuffle=False,
)
logging.info("About to create test dataloader")
test_dl = DataLoader(
test,
batch_size=None,
sampler=test_sampler,
num_workers=self.args.num_workers,
)
return test_dl
@lru_cache()
def train_cuts(self) -> CutSet:
logging.info("About to get train cuts")
return load_manifest_lazy(self.args.manifest_dir / "cuts_train.jsonl.gz")
@lru_cache()
def dev_cuts(self) -> CutSet:
logging.info("About to get dev cuts")
return load_manifest_lazy(self.args.manifest_dir / "cuts_dev.jsonl.gz")
@lru_cache()
def test_cuts(self) -> CutSet:
logging.info("About to get test cuts")
return load_manifest_lazy(self.args.manifest_dir / "cuts_test.jsonl.gz")
@lru_cache()
def dev_clean_cuts(self) -> CutSet:
logging.info("About to get dev-clean cuts")
return load_manifest_lazy(
self.args.manifest_dir / "libritts_cuts_dev-clean.jsonl.gz"
)
@lru_cache()
def dev_other_cuts(self) -> CutSet:
logging.info("About to get dev-other cuts")
return load_manifest_lazy(
self.args.manifest_dir / "libritts_cuts_dev-other.jsonl.gz"
)
@lru_cache()
def test_clean_cuts(self) -> CutSet:
logging.info("About to get test-clean cuts")
return load_manifest_lazy(
self.args.manifest_dir / "libritts_cuts_test-clean.jsonl.gz"
)
@lru_cache()
def test_other_cuts(self) -> CutSet:
logging.info("About to get test-other cuts")
return load_manifest_lazy(
self.args.manifest_dir / "libritts_cuts_test-other.jsonl.gz"
)
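# Illustrative usage (valle/train.py, not shown here, is expected to use this
# module roughly as follows):
#   parser = argparse.ArgumentParser()
#   TtsDataModule.add_arguments(parser)
#   args = parser.parse_args()
#   datamodule = TtsDataModule(args)
#   train_dl = datamodule.train_dataloaders(datamodule.train_cuts())
#   valid_dl = datamodule.dev_dataloaders(datamodule.dev_cuts())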

File diff suppressed because it is too large