mirror of https://github.com/k2-fsa/icefall.git
add valle

This commit is contained in:
  parent 57451b0382
  commit 5361ecdc56
575  egs/libritts/TTS/local/compute_neural_codec_and_prepare_text_tokens.py  (Executable file)
@@ -0,0 +1,575 @@
#!/usr/bin/env python3
# Copyright    2023                            (authors: Feiteng Li)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Phonemize Text and EnCodec Audio.

Usage example:
    python3 bin/tokenizer.py \
        --src_dir ./data/manifests --output_dir ./data/tokenized

"""
import argparse
import logging
import os
import re
from dataclasses import asdict, dataclass
from pathlib import Path
from typing import Any, Dict, List, Optional, Pattern, Union

import numpy as np
import torch
import torch.multiprocessing

# Imports needed by the classes defined below (codec, phonemizers, lhotse).
from encodec import EncodecModel
from encodec.utils import convert_audio
from k2 import SymbolTable
from lhotse import CutSet, NumpyHdf5Writer
from lhotse.features import FeatureExtractor
from lhotse.recipes.utils import read_manifests_if_cached
from lhotse.utils import Seconds, compute_num_frames
from phonemizer.backend import EspeakBackend
from phonemizer.backend.espeak.language_switch import LanguageSwitch
from phonemizer.backend.espeak.words_mismatch import WordMismatch
from phonemizer.punctuation import Punctuation
from phonemizer.separator import Separator
from pypinyin import Style, pinyin
from pypinyin.style._utils import get_finals, get_initials
from tqdm.auto import tqdm

from icefall.utils import get_executor

# from valle.data.fbank import get_fbank_extractor

os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"


# Torch's multithreaded behavior needs to be disabled, or
# it wastes a lot of CPU and slows things down.
# Do this outside of main() in case it needs to take effect
# even when we are not invoking the main (e.g. when spawning subprocesses).
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
torch.multiprocessing.set_sharing_strategy("file_system")

def get_args():
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--src-dir",
        type=Path,
        default=Path("data/manifests"),
        help="Path to the manifest files",
    )
    parser.add_argument(
        "--output-dir",
        type=Path,
        default=Path("data/tokenized"),
        help="Path to the tokenized files",
    )
    parser.add_argument(
        "--text-extractor",
        type=str,
        default="espeak",
        help="espeak or pypinyin or pypinyin_initials_finals",
    )
    parser.add_argument(
        "--audio-extractor",
        type=str,
        default="Encodec",
        help="Encodec or Fbank",
    )
    parser.add_argument(
        "--dataset-parts",
        type=str,
        default="dev-clean test-clean",
        help="Space separated dataset parts",
    )
    parser.add_argument(
        "--prefix",
        type=str,
        default="libritts",
        help="prefix of the manifest file",
    )
    parser.add_argument(
        "--suffix",
        type=str,
        default="jsonl.gz",
        help="suffix of the manifest file",
    )
    parser.add_argument(
        "--batch-duration",
        type=float,
        default=400.0,
        help="The maximum number of audio seconds in a batch. "
        "Determines batch size dynamically.",
    )
    parser.add_argument(
        "--split",
        type=int,
        default=1,
        help="Split the cut_set into multiple parts",
    )

    return parser.parse_args()


class PypinyinBackend:
    """PypinyinBackend for Chinese. Most of the code is adapted from espnet.

    There are two modes, pinyin and initials_finals: the former yields
    tokens like "ni1 hao3", the latter like "n i1 h ao3".
    """

    def __init__(
        self,
        backend="initials_finals",
        punctuation_marks: Union[str, Pattern] = Punctuation.default_marks(),
    ) -> None:
        self.backend = backend
        self.punctuation_marks = punctuation_marks

    def phonemize(
        self, text: List[str], separator: Separator, strip=True, njobs=1
    ) -> List[str]:
        assert isinstance(text, List)
        phonemized = []
        for _text in text:
            _text = re.sub(" +", " ", _text.strip())
            _text = _text.replace(" ", separator.word)
            phones = []
            if self.backend == "pypinyin":
                for n, py in enumerate(
                    pinyin(_text, style=Style.TONE3, neutral_tone_with_five=True)
                ):
                    if all([c in self.punctuation_marks for c in py[0]]):
                        if len(phones):
                            assert phones[-1] == separator.syllable
                            phones.pop(-1)

                        phones.extend(list(py[0]))
                    else:
                        phones.extend([py[0], separator.syllable])
            elif self.backend == "pypinyin_initials_finals":
                for n, py in enumerate(
                    pinyin(_text, style=Style.TONE3, neutral_tone_with_five=True)
                ):
                    if all([c in self.punctuation_marks for c in py[0]]):
                        if len(phones):
                            assert phones[-1] == separator.syllable
                            phones.pop(-1)
                        phones.extend(list(py[0]))
                    else:
                        if py[0][-1].isalnum():
                            initial = get_initials(py[0], strict=False)
                            if py[0][-1].isdigit():
                                final = (
                                    get_finals(py[0][:-1], strict=False) + py[0][-1]
                                )
                            else:
                                final = get_finals(py[0], strict=False)
                            phones.extend(
                                [
                                    initial,
                                    separator.phone,
                                    final,
                                    separator.syllable,
                                ]
                            )
                        else:
                            raise ValueError(f"Unexpected pinyin token: {py[0]}")
            else:
                raise NotImplementedError
            phonemized.append(
                "".join(phones).rstrip(f"{separator.word}{separator.syllable}")
            )
        return phonemized


class TextTokenizer:
    """Phonemize Text."""

    def __init__(
        self,
        language="en-us",
        backend="espeak",
        separator=Separator(word="_", syllable="-", phone="|"),
        preserve_punctuation=True,
        punctuation_marks: Union[str, Pattern] = Punctuation.default_marks(),
        with_stress: bool = False,
        tie: Union[bool, str] = False,
        language_switch: LanguageSwitch = "keep-flags",
        words_mismatch: WordMismatch = "ignore",
    ) -> None:
        if backend == "espeak":
            phonemizer = EspeakBackend(
                language,
                punctuation_marks=punctuation_marks,
                preserve_punctuation=preserve_punctuation,
                with_stress=with_stress,
                tie=tie,
                language_switch=language_switch,
                words_mismatch=words_mismatch,
            )
        elif backend in ["pypinyin", "pypinyin_initials_finals"]:
            phonemizer = PypinyinBackend(
                backend=backend,
                punctuation_marks=punctuation_marks + separator.word,
            )
        else:
            raise NotImplementedError(f"{backend}")

        self.backend = phonemizer
        self.separator = separator

    def to_list(self, phonemized: str) -> List[str]:
        fields = []
        for word in phonemized.split(self.separator.word):
            # "ɐ m|iː|n?" ɹ|ɪ|z|ɜː|v; h|ɪ|z.
            pp = re.findall(r"\w+|[^\w\s]", word, re.UNICODE)
            fields.extend(
                [p for p in pp if p != self.separator.phone] + [self.separator.word]
            )
        assert len("".join(fields[:-1])) == len(phonemized) - phonemized.count(
            self.separator.phone
        )
        return fields[:-1]

    def __call__(self, text, strip=True) -> List[List[str]]:
        if isinstance(text, str):
            text = [text]

        phonemized = self.backend.phonemize(
            text, separator=self.separator, strip=strip, njobs=1
        )
        return [self.to_list(p) for p in phonemized]


def tokenize_text(tokenizer: TextTokenizer, text: str) -> List[str]:
    phonemes = tokenizer([text.strip()])
    return phonemes[0]  # k2symbols


def remove_encodec_weight_norm(model):
    from encodec.modules import SConv1d
    from encodec.modules.seanet import SConvTranspose1d, SEANetResnetBlock
    from torch.nn.utils import remove_weight_norm

    encoder = model.encoder.model
    for key in encoder._modules:
        if isinstance(encoder._modules[key], SEANetResnetBlock):
            remove_weight_norm(encoder._modules[key].shortcut.conv.conv)
            block_modules = encoder._modules[key].block._modules
            for skey in block_modules:
                if isinstance(block_modules[skey], SConv1d):
                    remove_weight_norm(block_modules[skey].conv.conv)
        elif isinstance(encoder._modules[key], SConv1d):
            remove_weight_norm(encoder._modules[key].conv.conv)

    decoder = model.decoder.model
    for key in decoder._modules:
        if isinstance(decoder._modules[key], SEANetResnetBlock):
            remove_weight_norm(decoder._modules[key].shortcut.conv.conv)
            block_modules = decoder._modules[key].block._modules
            for skey in block_modules:
                if isinstance(block_modules[skey], SConv1d):
                    remove_weight_norm(block_modules[skey].conv.conv)
        elif isinstance(decoder._modules[key], SConvTranspose1d):
            remove_weight_norm(decoder._modules[key].convtr.convtr)
        elif isinstance(decoder._modules[key], SConv1d):
            remove_weight_norm(decoder._modules[key].conv.conv)


class AudioTokenizer:
    """EnCodec audio."""

    def __init__(
        self,
        device: Any = None,
    ) -> None:
        # Instantiate a pretrained EnCodec model
        model = EncodecModel.encodec_model_24khz()
        model.set_target_bandwidth(6.0)
        remove_encodec_weight_norm(model)

        if not device:
            device = torch.device("cpu")
            if torch.cuda.is_available():
                device = torch.device("cuda:0")

        self._device = device

        self.codec = model.to(device)
        self.sample_rate = model.sample_rate
        self.channels = model.channels

    @property
    def device(self):
        return self._device

    def encode(self, wav: torch.Tensor) -> torch.Tensor:
        return self.codec.encode(wav.to(self.device))

    def decode(self, frames: torch.Tensor) -> torch.Tensor:
        return self.codec.decode(frames)


@dataclass
class AudioTokenConfig:
    frame_shift: Seconds = 320.0 / 24000
    num_quantizers: int = 8

    def to_dict(self) -> Dict[str, Any]:
        return asdict(self)

    @staticmethod
    def from_dict(data: Dict[str, Any]) -> "AudioTokenConfig":
        return AudioTokenConfig(**data)


class AudioTokenExtractor(FeatureExtractor):
    name = "encodec"
    config_type = AudioTokenConfig

    def __init__(self, config: Optional[Any] = None):
        super(AudioTokenExtractor, self).__init__(config)
        self.tokenizer = AudioTokenizer()

    def extract(
        self, samples: Union[np.ndarray, torch.Tensor], sampling_rate: int
    ) -> np.ndarray:
        if not isinstance(samples, torch.Tensor):
            samples = torch.from_numpy(samples)
        if sampling_rate != self.tokenizer.sample_rate:
            samples = convert_audio(
                samples,
                sampling_rate,
                self.tokenizer.sample_rate,
                self.tokenizer.channels,
            )
        if len(samples.shape) == 2:
            samples = samples.unsqueeze(0)
        else:
            raise ValueError(
                f"Expected 2-D (channel, time) samples, got shape {tuple(samples.shape)}"
            )

        device = self.tokenizer.device
        encoded_frames = self.tokenizer.encode(samples.detach().to(device))
        codes = encoded_frames[0][0]  # [B, n_q, T]
        # Trim the codes to the number of frames lhotse expects for this duration.
        duration = round(samples.shape[-1] / sampling_rate, ndigits=12)
        expected_num_frames = compute_num_frames(
            duration=duration,
            frame_shift=self.frame_shift,
            sampling_rate=sampling_rate,
        )
        assert abs(codes.shape[-1] - expected_num_frames) <= 1
        codes = codes[..., :expected_num_frames]
        return codes.cpu().squeeze(0).permute(1, 0).numpy()

    @property
    def frame_shift(self) -> Seconds:
        return self.config.frame_shift

    def feature_dim(self, sampling_rate: int) -> int:
        return self.config.num_quantizers

    def pad_tensor_list(self, tensor_list, device, padding_value=0):
        # Record the original length of each tensor.
        lengths = [tensor.shape[0] for tensor in tensor_list]
        # Pad all tensors to a common length with pad_sequence.
        tensor_list = [torch.Tensor(t).to(device) for t in tensor_list]
        padded_tensor = torch.nn.utils.rnn.pad_sequence(
            tensor_list, batch_first=True, padding_value=padding_value
        )
        return padded_tensor, lengths

    def extract_batch(self, samples, sampling_rate, lengths) -> np.ndarray:
        samples = [wav.squeeze() for wav in samples]
        device = self.tokenizer.device
        samples, lengths = self.pad_tensor_list(samples, device)
        samples = samples.unsqueeze(1)

        if not isinstance(samples, torch.Tensor):
            samples = torch.from_numpy(samples)
        if len(samples.shape) != 3:
            raise ValueError(
                f"Expected 3-D (batch, channel, time) samples, got shape {tuple(samples.shape)}"
            )
        if sampling_rate != self.tokenizer.sample_rate:
            samples = [
                convert_audio(
                    wav,
                    sampling_rate,
                    self.tokenizer.sample_rate,
                    self.tokenizer.channels,
                )
                for wav in samples
            ]
            samples = torch.stack(samples, 0)  # convert samples from list to tensor
        # Extract discrete codes from EnCodec
        with torch.no_grad():
            encoded_frames = self.tokenizer.encode(samples.detach().to(device))
        encoded_frames = encoded_frames[0][0]  # [B, n_q, T]
        batch_codes = []
        for b, length in enumerate(lengths):
            codes = encoded_frames[b]
            duration = round(length / sampling_rate, ndigits=12)
            expected_num_frames = compute_num_frames(
                duration=duration,
                frame_shift=self.frame_shift,
                sampling_rate=sampling_rate,
            )
            batch_codes.append(codes[..., :expected_num_frames])
        return [codes.cpu().permute(1, 0).numpy() for codes in batch_codes]


def main():
    args = get_args()

    dataset_parts = args.dataset_parts.replace("--dataset-parts", "").strip()
    if dataset_parts == "all":  # LibriTTS
        dataset_parts = [
            "dev-clean",
            "dev-other",
            "test-clean",
            "test-other",
            "train-clean-100",
            "train-clean-360",
            "train-other-500",
        ]
    else:
        dataset_parts = dataset_parts.replace("-p", "").strip().split(" ")

    assert len(dataset_parts) >= 1

    manifests = read_manifests_if_cached(
        dataset_parts=dataset_parts,
        output_dir=args.src_dir,
        prefix=args.prefix,
        suffix=args.suffix,
        types=["recordings", "supervisions", "cuts"],
    )

    text_tokenizer = None
    if args.text_extractor:
        text_tokenizer = TextTokenizer(backend=args.text_extractor)

    audio_extractor = None
    if args.audio_extractor:
        if args.audio_extractor == "Encodec":
            audio_extractor = AudioTokenExtractor(AudioTokenConfig())
        else:
            assert args.audio_extractor == "Fbank"
            audio_extractor = get_fbank_extractor()

    Path(args.output_dir).mkdir(parents=True, exist_ok=True)
    unique_symbols = set()
    num_jobs = min(32, os.cpu_count())
    logging.info(f"dataset_parts: {dataset_parts} manifests {len(manifests)}")

    prefix = args.prefix
    if prefix and not prefix.endswith("_"):
        prefix = f"{prefix}_"
    with get_executor() as ex:
        for partition, m in manifests.items():
            logging.info(
                f"Processing partition: {partition} CUDA: {torch.cuda.is_available()}"
            )
            try:
                cut_set = CutSet.from_manifests(
                    recordings=m["recordings"],
                    supervisions=m["supervisions"],
                )
            except Exception:
                cut_set = m["cuts"]

            # Split cut_set if split > 1
            split = 1
            if args.split > 1:
                cut_sets = cut_set.split(args.split)
                split = args.split
            else:
                cut_sets = [cut_set]

            for idx, part in enumerate(cut_sets):
                # AudioTokenizer
                if args.audio_extractor:
                    if args.audio_extractor == "Encodec":
                        storage_path = f"{args.output_dir}/{args.prefix}_encodec_{partition}_{idx if split > 1 else ''}"
                    else:
                        storage_path = f"{args.output_dir}/{args.prefix}_fbank_{partition}_{idx if split > 1 else ''}"

                    if args.prefix.lower() in [
                        "ljspeech",
                        "aishell",
                        "baker",
                        "wenetspeech4tts",
                    ]:
                        part = part.resample(24000)

                    with torch.no_grad():
                        if (
                            torch.cuda.is_available()
                            and args.audio_extractor == "Encodec"
                        ):
                            part = part.compute_and_store_features_batch(
                                extractor=audio_extractor,
                                storage_path=storage_path,
                                num_workers=num_jobs,
                                batch_duration=args.batch_duration,
                                collate=False,
                                overwrite=True,
                                storage_type=NumpyHdf5Writer,
                            )
                        else:
                            part = part.compute_and_store_features(
                                extractor=audio_extractor,
                                storage_path=storage_path,
                                num_jobs=num_jobs if ex is None else 64,
                                executor=ex,
                                storage_type=NumpyHdf5Writer,
                            )

                # TextTokenizer
                if args.text_extractor:
                    for c in tqdm(part):
                        if (
                            args.prefix == "baker"
                            and args.text_extractor == "labeled_pinyin"
                        ):
                            phonemes = c.supervisions[0].custom["tokens"]["text"]
                            unique_symbols.update(phonemes)
                        else:
                            if args.prefix == "ljspeech":
                                text = c.supervisions[0].custom["normalized_text"]
                                text = text.replace("”", '"').replace("“", '"')
                                phonemes = tokenize_text(text_tokenizer, text=text)
                            elif args.prefix in [
                                "aishell",
                                "aishell2",
                                "wenetspeech4tts",
                                "libritts",
                            ]:
                                phonemes = tokenize_text(
                                    text_tokenizer, text=c.supervisions[0].text
                                )
                                if c.supervisions[0].custom is None:
                                    c.supervisions[0].custom = {}
                            else:
                                raise NotImplementedError(f"{args.prefix}")
                            c.supervisions[0].custom["tokens"] = {"text": phonemes}
                            unique_symbols.update(phonemes)

                # Save each part with an index if split > 1
                cuts_filename = (
                    f"{prefix}cuts_{partition}.{idx if split > 1 else ''}.{args.suffix}"
                )
                part.to_file(f"{args.output_dir}/{cuts_filename}")
                logging.info(f"Saved {cuts_filename}")

    if args.text_extractor:
        unique_phonemes = SymbolTable()
        for s in sorted(list(unique_symbols)):
            unique_phonemes.add(s)
        logging.info(f"{len(unique_symbols)} unique phonemes: {unique_symbols}")

        unique_phonemes_file = f"{args.output_dir}/unique_text_tokens.k2symbols"
        unique_phonemes.to_file(unique_phonemes_file)


if __name__ == "__main__":
    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
    logging.basicConfig(format=formatter, level=logging.INFO)
    main()
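
A note on the token geometry defined above: EnCodec at 24 kHz with a 6 kbps target bandwidth emits 8 codebook indices every 320 samples, i.e. 75 frames per second, which is exactly what AudioTokenConfig encodes and why AudioTokenExtractor.extract returns an array of shape (T, 8). A minimal sketch of that arithmetic (the helper name is illustrative, not part of the recipe):

def expected_token_shape(num_samples: int, sampling_rate: int = 24000):
    frame_shift = 320.0 / 24000  # seconds per EnCodec frame, as in AudioTokenConfig
    duration = num_samples / sampling_rate
    # lhotse's compute_num_frames is the careful version of this rounding.
    num_frames = round(duration / frame_shift)
    return (num_frames, 8)  # (T, num_quantizers), matching extract()'s output

print(expected_token_shape(24000))  # one second of audio -> (75, 8)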
304  egs/libritts/TTS/valle/infer.py  (Normal file)
@@ -0,0 +1,304 @@
#!/usr/bin/env python3
# Copyright    2023                            (authors: Feiteng Li)
# Copyright    2024                            (authors: Yuekai Zhang)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Phonemize Text and EnCodec Audio.

Usage example:
    python3 bin/infer.py --output-dir demos_epoch_${epoch}_avg_${avg} \
        --checkpoint=${exp_dir}/epoch-${epoch}-avg-${avg}.pt \
        --text-prompts "KNOT one point one five miles per hour." \
        --audio-prompts ./prompts/8463_294825_000043_000000.wav \
        --text "To get up and running quickly just follow the steps below."

    python3 bin/infer.py --output-dir demos_epoch_${epoch}_avg_${avg} \
        --top-k -1 --temperature 1.0 \
        --text-prompts "" \
        --audio-prompts "" \
        --text ./libritts.txt \
        --checkpoint ${exp_dir}/epoch-${epoch}-avg-${avg}.pt

"""
import argparse
import logging
import os
from pathlib import Path

os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"

import torch
import torchaudio
from icefall.utils import AttributeDict, str2bool

from valle.data import (
    AudioTokenizer,
    TextTokenizer,
    tokenize_audio,
    tokenize_text,
)
from valle.data.collation import get_text_token_collater
from valle.models import get_model


def get_args():
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--text-prompts",
        type=str,
        default="",
        help="Text prompts which are separated by |.",
    )

    parser.add_argument(
        "--audio-prompts",
        type=str,
        default="",
        help="Audio prompts which are separated by | and should be aligned with --text-prompts.",
    )

    parser.add_argument(
        "--text",
        type=str,
        default="To get up and running quickly just follow the steps below.",
        help="Text to be synthesized.",
    )

    # model
    # add_model_arguments(parser)
    # parser.add_argument(
    #     "--text-tokens",
    #     type=str,
    #     default="data/tokenized/unique_text_tokens.k2symbols",
    #     help="Path to the unique text tokens file.",
    # )

    parser.add_argument(
        "--text-extractor",
        type=str,
        default="espeak",
        help="espeak or pypinyin or pypinyin_initials_finals",
    )

    parser.add_argument(
        "--checkpoint",
        type=str,
        default="exp/vallf_nano_full/checkpoint-100000.pt",
        help="Path to the saved checkpoint.",
    )

    parser.add_argument(
        "--output-dir",
        type=Path,
        default=Path("infer/demo"),
        help="Path to the tokenized files.",
    )

    parser.add_argument(
        "--top-k",
        type=int,
        default=-100,
        help="Whether the AR decoder does top-k sampling (enabled if > 0).",
    )

    parser.add_argument(
        "--top-p",
        type=float,
        default=1.0,
        help="Whether the AR decoder does top-p sampling (enabled if > 0).",
    )

    parser.add_argument(
        "--temperature",
        type=float,
        default=1.0,
        help="The temperature for the AR decoder's top-k sampling.",
    )

    parser.add_argument(
        "--continual",
        type=str2bool,
        default=False,
        help="Do the continual (audio continuation) task.",
    )

    return parser.parse_args()


def load_model(checkpoint, device):
    if not checkpoint:
        return None

    checkpoint = torch.load(checkpoint, map_location=device)

    args = AttributeDict(checkpoint)
    model = get_model(args)

    missing_keys, unexpected_keys = model.load_state_dict(
        checkpoint["model"], strict=True
    )
    assert not missing_keys
    model.to(device)
    model.eval()

    text_tokens = args.text_tokens

    return model, text_tokens


@torch.no_grad()
def main():
    args = get_args()
    text_tokenizer = TextTokenizer(backend=args.text_extractor)

    device = torch.device("cpu")
    if torch.cuda.is_available():
        device = torch.device("cuda", 0)
    model, text_tokens = load_model(args.checkpoint, device)
    text_collater = get_text_token_collater(text_tokens)

    audio_tokenizer = AudioTokenizer()

    Path(args.output_dir).mkdir(parents=True, exist_ok=True)

    text_prompts = " ".join(args.text_prompts.split("|"))

    audio_prompts = []
    if args.audio_prompts:
        for n, audio_file in enumerate(args.audio_prompts.split("|")):
            encoded_frames = tokenize_audio(audio_tokenizer, audio_file)
            # To listen to a prompt after the codec round trip, decode it here:
            # samples = audio_tokenizer.decode(encoded_frames)
            # torchaudio.save(f"{args.output_dir}/p{n}.wav", samples[0], 24000)

            audio_prompts.append(encoded_frames[0][0])

        assert len(args.text_prompts.split("|")) == len(audio_prompts)
        audio_prompts = torch.concat(audio_prompts, dim=-1).transpose(2, 1)
        audio_prompts = audio_prompts.to(device)

    if os.path.isfile(args.text):  # for demos
        # https://github.com/lifeiteng/lifeiteng.github.com/blob/main/valle/prepare.py
        with open(args.text) as f:
            for line in f:
                # fields = line.strip().split("\t")
                fields = line.strip().split(" ")
                fields = [item for item in fields if item]
                assert len(fields) == 4
                prompt_text, prompt_audio, text, audio_path = fields
                logging.info(f"synthesize text: {text}")
                text_tokens, text_tokens_lens = text_collater(
                    [
                        tokenize_text(
                            text_tokenizer, text=f"{prompt_text} {text}".strip()
                        )
                    ]
                )
                _, enroll_x_lens = text_collater(
                    [tokenize_text(text_tokenizer, text=f"{prompt_text}".strip())]
                )

                audio_prompts = tokenize_audio(audio_tokenizer, prompt_audio)
                audio_prompts = audio_prompts[0][0].transpose(2, 1).to(device)

                # synthesis
                encoded_frames = model.inference(
                    text_tokens.to(device),
                    text_tokens_lens.to(device),
                    audio_prompts,
                    enroll_x_lens=enroll_x_lens,
                    top_k=args.top_k,
                    temperature=args.temperature,
                    top_p=args.top_p,
                )

                samples = audio_tokenizer.decode(
                    [(encoded_frames.transpose(2, 1), None)]
                )
                # Save the synthesized audio under args.output_dir/audio_path.
                audio_path = f"{args.output_dir}/{audio_path}"
                os.makedirs(os.path.dirname(audio_path), exist_ok=True)
                torchaudio.save(audio_path, samples[0].cpu(), 24000)
        return

    for n, text in enumerate(args.text.split("|")):
        logging.info(f"synthesize text: {text}")
        text_tokens, text_tokens_lens = text_collater(
            [
                tokenize_text(
                    text_tokenizer, text=f"{text_prompts} {text}".strip()
                )
            ]
        )

        # synthesis
        if args.continual:
            assert text == ""
            encoded_frames = model.continual(
                text_tokens.to(device),
                text_tokens_lens.to(device),
                audio_prompts,
            )
        else:
            enroll_x_lens = None
            if text_prompts:
                _, enroll_x_lens = text_collater(
                    [tokenize_text(text_tokenizer, text=f"{text_prompts}".strip())]
                )
            encoded_frames = model.inference(
                text_tokens.to(device),
                text_tokens_lens.to(device),
                audio_prompts,
                enroll_x_lens=enroll_x_lens,
                top_k=args.top_k,
                temperature=args.temperature,
                top_p=args.top_p,
            )

        if audio_prompts != []:
            samples = audio_tokenizer.decode(
                [(encoded_frames.transpose(2, 1), None)]
            )
            # store
            torchaudio.save(f"{args.output_dir}/{n}.wav", samples[0].cpu(), 24000)
        else:  # Transformer
            pass


torch.set_num_threads(1)
torch.set_num_interop_threads(1)
torch._C._jit_set_profiling_executor(False)
torch._C._jit_set_profiling_mode(False)
torch._C._set_graph_executor_optimize(False)

if __name__ == "__main__":
    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
    logging.basicConfig(format=formatter, level=logging.INFO)
    main()
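
When --text points at a file rather than a literal sentence, the loop above expects exactly four whitespace-separated fields per line: prompt transcript, prompt wav path, target text, and output wav path. As written, the code splits on single spaces and asserts four fields, so each field must itself contain no spaces; for multi-word transcripts, the tab-separated split left in the comment is the variant that works. A hypothetical tab-separated line (paths are placeholders):

prompt transcript<TAB>prompts/p1.wav<TAB>target text to synthesize<TAB>out/utt1.wav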
1  egs/libritts/TTS/valle/optim.py  (Symbolic link)
@@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/optim.py
121  egs/libritts/TTS/valle/tokenizer.py  (Normal file)
@@ -0,0 +1,121 @@
from pathlib import Path
from typing import List, Tuple

import numpy as np
import torch

from k2 import SymbolTable


class TextTokenCollater:
    """Collate list of text tokens

    Map sentences to integers. Sentences are padded to equal length.
    Beginning and end-of-sequence symbols can be added.

    Example:
        >>> token_collater = TextTokenCollater(text_tokens)
        >>> tokens_batch, tokens_lens = token_collater(text)

    Returns:
        tokens_batch: IntTensor of shape (B, L)
            B: batch dimension, number of input sentences
            L: length of the longest sentence
        tokens_lens: IntTensor of shape (B,)
            Length of each sentence after adding <eos> and <bos>
            but before padding.
    """

    def __init__(
        self,
        text_tokens: List[str],
        add_eos: bool = True,
        add_bos: bool = True,
        pad_symbol: str = "<pad>",
        bos_symbol: str = "<bos>",
        eos_symbol: str = "<eos>",
    ):
        self.pad_symbol = pad_symbol

        self.add_eos = add_eos
        self.add_bos = add_bos

        self.bos_symbol = bos_symbol
        self.eos_symbol = eos_symbol

        unique_tokens = (
            [pad_symbol]
            + ([bos_symbol] if add_bos else [])
            + ([eos_symbol] if add_eos else [])
            + sorted(text_tokens)
        )

        self.token2idx = {token: idx for idx, token in enumerate(unique_tokens)}
        self.idx2token = [token for token in unique_tokens]

    def index(self, tokens_list: List[str]) -> Tuple[torch.Tensor, torch.Tensor]:
        seqs, seq_lens = [], []
        for tokens in tokens_list:
            assert all(s in self.token2idx for s in tokens)
            seq = (
                ([self.bos_symbol] if self.add_bos else [])
                + list(tokens)
                + ([self.eos_symbol] if self.add_eos else [])
            )
            seqs.append(seq)
            seq_lens.append(len(seq))

        max_len = max(seq_lens)
        for k, (seq, seq_len) in enumerate(zip(seqs, seq_lens)):
            seq.extend([self.pad_symbol] * (max_len - seq_len))

        tokens = torch.from_numpy(
            np.array(
                [[self.token2idx[token] for token in seq] for seq in seqs],
                dtype=np.int64,
            )
        )
        tokens_lens = torch.IntTensor(seq_lens)

        return tokens, tokens_lens

    def __call__(self, texts: List[str]) -> Tuple[torch.Tensor, torch.Tensor]:
        tokens_seqs = [[p for p in text] for text in texts]
        max_len = len(max(tokens_seqs, key=len))

        seqs = [
            ([self.bos_symbol] if self.add_bos else [])
            + list(seq)
            + ([self.eos_symbol] if self.add_eos else [])
            + [self.pad_symbol] * (max_len - len(seq))
            for seq in tokens_seqs
        ]

        tokens_batch = torch.from_numpy(
            np.array(
                [[self.token2idx[token] for token in seq] for seq in seqs],
                dtype=np.int64,
            )
        )

        tokens_lens = torch.IntTensor(
            [len(seq) + int(self.add_eos) + int(self.add_bos) for seq in tokens_seqs]
        )

        return tokens_batch, tokens_lens


def get_text_token_collater(text_tokens_file: str) -> TextTokenCollater:
    text_tokens_path = Path(text_tokens_file)
    unique_tokens = SymbolTable.from_file(text_tokens_path)
    collater = TextTokenCollater(unique_tokens.symbols, add_bos=True, add_eos=True)
    return collater
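
A quick sanity check of the collater's contract (the three-token inventory below is hypothetical; in the recipe the symbols come from the unique_text_tokens.k2symbols file via get_text_token_collater):

collater = TextTokenCollater(["a", "b", "c"], add_bos=True, add_eos=True)
tokens_batch, tokens_lens = collater([["a", "b"], ["c"]])
# tokens_batch has shape (2, 4): <bos> a b <eos>  /  <bos> c <eos> <pad>
# tokens_lens is tensor([4, 3]): lengths count <bos>/<eos> but not padding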
1287  egs/libritts/TTS/valle/train.py  (Executable file)
(File diff suppressed because it is too large)
344  egs/libritts/TTS/valle/tts_datamodule.py  (Normal file)
@@ -0,0 +1,344 @@
# Copyright      2021  Piotr Żelasko
# Copyright 2022-2024  Xiaomi Corporation  (Authors: Mingshuang Luo,
#                                                    Zengwei Yao,
#                                                    Zengrui Jin,)
# Copyright      2023                      (authors: Feiteng Li)
# Copyright      2024                      (Author: Yuekai Zhang)
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import argparse
import logging
from functools import lru_cache
from pathlib import Path
from typing import Any, Dict, Optional

import torch
from lhotse import CutSet, Spectrogram, SpectrogramConfig, load_manifest_lazy
from lhotse.dataset import (  # noqa F401 for PrecomputedFeatures
    CutConcatenate,
    DynamicBucketingSampler,
    PrecomputedFeatures,
    SimpleCutSampler,
    SpeechSynthesisDataset,
)
from lhotse.dataset.input_strategies import (  # noqa F401 For AudioSamples
    AudioSamples,
    OnTheFlyFeatures,
)
from lhotse.features.io import KaldiReader
from lhotse.utils import fix_random_seed
from torch.utils.data import DataLoader

from icefall.utils import str2bool


class _SeedWorkers:
    def __init__(self, seed: int):
        self.seed = seed

    def __call__(self, worker_id: int):
        fix_random_seed(self.seed + worker_id)


class TtsDataModule:
    """
    DataModule for TTS experiments.
    It assumes there is always one train and valid dataloader,
    but there can be multiple test dataloaders (e.g. LibriSpeech test-clean
    and test-other).

    It contains all the common data pipeline modules used in TTS
    experiments, e.g.:
    - dynamic batch size,
    - bucketing samplers,
    - cut concatenation,
    - on-the-fly feature extraction

    This class should be derived for specific corpora used in TTS tasks.
    """
    def __init__(self, args: argparse.Namespace):
        self.args = args

    @classmethod
    def add_arguments(cls, parser: argparse.ArgumentParser):
        group = parser.add_argument_group(
            title="TTS data related options",
            description="These options are used for the preparation of "
            "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
            "effective batch sizes, sampling strategies, applied data "
            "augmentations, etc.",
        )
        group.add_argument(
            "--manifest-dir",
            type=Path,
            default=Path("data/tokenized"),
            help="Path to directory with train/valid/test cuts.",
        )
        group.add_argument(
            "--speaker-embeds",
            type=Path,
            default=Path("exp/xvector_nnet_1a/"),
            help="Path to directory with speaker embeddings.",
        )
        group.add_argument(
            "--max-duration",
            type=int,
            default=200.0,
            help="Maximum pooled recordings duration (seconds) in a "
            "single batch. You can reduce it if it causes CUDA OOM.",
        )
        group.add_argument(
            "--bucketing-sampler",
            type=str2bool,
            default=True,
            help="When enabled, the batches will come from buckets of "
            "similar duration (saves padding frames).",
        )
        group.add_argument(
            "--num-buckets",
            type=int,
            default=30,
            help="The number of buckets for the DynamicBucketingSampler "
            "(you might want to increase it for larger datasets).",
        )

        group.add_argument(
            "--on-the-fly-feats",
            type=str2bool,
            default=False,
            help="When enabled, use on-the-fly cut mixing and feature "
            "extraction. Will drop existing precomputed feature manifests "
            "if available.",
        )
        group.add_argument(
            "--shuffle",
            type=str2bool,
            default=True,
            help="When enabled (=default), the examples will be "
            "shuffled for each epoch.",
        )
        group.add_argument(
            "--drop-last",
            type=str2bool,
            default=True,
            help="Whether to drop last batch. Used by sampler.",
        )
        group.add_argument(
            "--return-cuts",
            type=str2bool,
            default=True,
            help="When enabled, each batch will have the "
            "field: batch['cut'] with the cuts that "
            "were used to construct it.",
        )
        group.add_argument(
            "--num-workers",
            type=int,
            default=4,
            help="The number of training dataloader workers that "
            "collect the batches.",
        )

        group.add_argument(
            "--enable-spec-aug",
            type=str2bool,
            default=False,
            help="When enabled, use SpecAugment for training dataset.",
        )

        group.add_argument(
            "--input-strategy",
            type=str,
            default="PrecomputedFeatures",
            help="AudioSamples or PrecomputedFeatures",
        )

        group.add_argument(
            "--dataset",
            type=str,
            default="libritts",
            help="--input-strategy PromptedPrecomputedFeatures needs "
            "the dataset name to prepare prompts.",
        )

        parser.add_argument(
            "--sampling-rate",
            type=int,
            default=24000,
            help="""Audio sampling rate.""",
        )

    def train_dataloaders(
        self,
        cuts_train: CutSet,
        sampler_state_dict: Optional[Dict[str, Any]] = None,
    ) -> DataLoader:
        """
        Args:
          cuts_train:
            CutSet for training.
          sampler_state_dict:
            The state dict for the training sampler.
        """
        logging.info("About to create train dataset")
        train = SpeechSynthesisDataset(
            return_text=False,
            return_tokens=False,
            return_spk_ids=False,
            feature_input_strategy=eval(self.args.input_strategy)(),
            return_cuts=self.args.return_cuts,
        )

        if self.args.on_the_fly_feats:
            raise NotImplementedError

        if self.args.bucketing_sampler:
            logging.info("Using DynamicBucketingSampler.")
            train_sampler = DynamicBucketingSampler(
                cuts_train,
                max_duration=self.args.max_duration,
                shuffle=self.args.shuffle,
                num_buckets=self.args.num_buckets,
                buffer_size=self.args.num_buckets * 2000,
                shuffle_buffer_size=self.args.num_buckets * 5000,
                drop_last=self.args.drop_last,
            )
        else:
            logging.info("Using SimpleCutSampler.")
            train_sampler = SimpleCutSampler(
                cuts_train,
                max_duration=self.args.max_duration,
                shuffle=self.args.shuffle,
            )
        logging.info("About to create train dataloader")

        if sampler_state_dict is not None:
            logging.info("Loading sampler state dict")
            train_sampler.load_state_dict(sampler_state_dict)

        # 'seed' is derived from the current random state, which will have
        # previously been set in the main process.
        seed = torch.randint(0, 100000, ()).item()
        worker_init_fn = _SeedWorkers(seed)

        train_dl = DataLoader(
            train,
            sampler=train_sampler,
            batch_size=None,
            num_workers=self.args.num_workers,
            persistent_workers=False,
            worker_init_fn=worker_init_fn,
        )

        return train_dl

    def dev_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
        logging.info("About to create dev dataset")
        if self.args.on_the_fly_feats:
            raise NotImplementedError
        else:
            validate = SpeechSynthesisDataset(
                return_text=False,
                return_tokens=False,
                return_spk_ids=False,
                feature_input_strategy=eval(self.args.input_strategy)(),
                return_cuts=self.args.return_cuts,
            )
        dev_sampler = DynamicBucketingSampler(
            cuts_valid,
            max_duration=self.args.max_duration,
            shuffle=False,
        )
        logging.info("About to create valid dataloader")
        dev_dl = DataLoader(
            validate,
            sampler=dev_sampler,
            batch_size=None,
            num_workers=self.args.num_workers,
            persistent_workers=False,
        )

        return dev_dl

    def test_dataloaders(self, cuts: CutSet) -> DataLoader:
        logging.info("About to create test dataset")
        if self.args.on_the_fly_feats:
            raise NotImplementedError
        else:
            test = SpeechSynthesisDataset(
                return_text=False,
                return_tokens=False,
                return_spk_ids=False,
                feature_input_strategy=eval(self.args.input_strategy)(),
                return_cuts=self.args.return_cuts,
            )
        test_sampler = DynamicBucketingSampler(
            cuts,
            max_duration=self.args.max_duration,
            shuffle=False,
        )
        logging.info("About to create test dataloader")
        test_dl = DataLoader(
            test,
            batch_size=None,
            sampler=test_sampler,
            num_workers=self.args.num_workers,
        )
        return test_dl

    @lru_cache()
    def train_cuts(self) -> CutSet:
        logging.info("About to get train cuts")
        return load_manifest_lazy(self.args.manifest_dir / "cuts_train.jsonl.gz")

    @lru_cache()
    def dev_cuts(self) -> CutSet:
        logging.info("About to get dev cuts")
        return load_manifest_lazy(self.args.manifest_dir / "cuts_dev.jsonl.gz")

    @lru_cache()
    def test_cuts(self) -> CutSet:
        logging.info("About to get test cuts")
        return load_manifest_lazy(self.args.manifest_dir / "cuts_test.jsonl.gz")

    @lru_cache()
    def dev_clean_cuts(self) -> CutSet:
        logging.info("About to get dev-clean cuts")
        return load_manifest_lazy(
            self.args.manifest_dir / "libritts_cuts_dev-clean.jsonl.gz"
        )

    @lru_cache()
    def dev_other_cuts(self) -> CutSet:
        logging.info("About to get dev-other cuts")
        return load_manifest_lazy(
            self.args.manifest_dir / "libritts_cuts_dev-other.jsonl.gz"
        )

    @lru_cache()
    def test_clean_cuts(self) -> CutSet:
        logging.info("About to get test-clean cuts")
        return load_manifest_lazy(
            self.args.manifest_dir / "libritts_cuts_test-clean.jsonl.gz"
        )

    @lru_cache()
    def test_other_cuts(self) -> CutSet:
        logging.info("About to get test-other cuts")
        return load_manifest_lazy(
            self.args.manifest_dir / "libritts_cuts_test-other.jsonl.gz"
        )
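
For orientation, a minimal sketch of how this datamodule is typically wired up (the argument values and surrounding script are illustrative; train.py, whose diff is suppressed below, does the real wiring):

import argparse

parser = argparse.ArgumentParser()
TtsDataModule.add_arguments(parser)
args = parser.parse_args(["--manifest-dir", "data/tokenized", "--max-duration", "200"])

datamodule = TtsDataModule(args)
# libritts_cuts_dev-clean.jsonl.gz must exist under --manifest-dir, produced by
# compute_neural_codec_and_prepare_text_tokens.py above.
dev_dl = datamodule.dev_dataloaders(datamodule.dev_clean_cuts())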
1822  egs/libritts/TTS/valle/valle.py  (Normal file)
(File diff suppressed because it is too large)