add valle

root 2024-11-19 04:21:32 +00:00
parent 57451b0382
commit 5361ecdc56
7 changed files with 4454 additions and 0 deletions

View File

@@ -0,0 +1,575 @@
#!/usr/bin/env python3
# Copyright 2023 (authors: Feiteng Li)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Phonemize Text and EnCodec Audio.
Usage example:
python3 bin/tokenizer.py \
--src_dir ./data/manifests --output_dir ./data/tokenized
"""
import argparse
import logging
import os
import re
from dataclasses import asdict, dataclass
from pathlib import Path
from typing import Any, Dict, List, Optional, Pattern, Union

import numpy as np
import torch
import torch.multiprocessing
from encodec import EncodecModel
from encodec.utils import convert_audio
from icefall.utils import get_executor
from lhotse import CutSet, NumpyHdf5Writer
from lhotse.features import FeatureExtractor
from lhotse.recipes.utils import read_manifests_if_cached
from lhotse.utils import Seconds, compute_num_frames
from phonemizer.backend import EspeakBackend
from phonemizer.backend.espeak.language_switch import LanguageSwitch
from phonemizer.backend.espeak.words_mismatch import WordMismatch
from phonemizer.punctuation import Punctuation
from phonemizer.separator import Separator
from pypinyin import Style, pinyin
from pypinyin.style._utils import get_finals, get_initials
from tqdm.auto import tqdm
from valle.data import (
    AudioTokenConfig,
    AudioTokenExtractor,
    TextTokenizer,
    tokenize_text,
)

# from valle.data.fbank import get_fbank_extractor
from valle.utils import SymbolTable
os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"
# Torch's multithreaded behavior needs to be disabled, or it wastes a lot of
# CPU and slows things down.
# Do this outside of main() in case it needs to take effect even when we are
# not invoking main() (e.g. when spawning subprocesses).
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
torch.multiprocessing.set_sharing_strategy("file_system")
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--src-dir",
type=Path,
default=Path("data/manifests"),
help="Path to the manifest files",
)
parser.add_argument(
"--output-dir",
type=Path,
default=Path("data/tokenized"),
help="Path to the tokenized files",
)
parser.add_argument(
"--text-extractor",
type=str,
default="espeak",
help="espeak or pypinyin or pypinyin_initials_finals",
)
parser.add_argument(
"--audio-extractor",
type=str,
default="Encodec",
help="Encodec or Fbank",
)
parser.add_argument(
"--dataset-parts",
type=str,
default="dev-clean test-clean",
help="Space separated dataset parts",
)
parser.add_argument(
"--prefix",
type=str,
default="libritts",
help="prefix of the manifest file",
)
parser.add_argument(
"--suffix",
type=str,
default="jsonl.gz",
help="suffix of the manifest file",
)
parser.add_argument(
"--batch-duration",
type=float,
default=400.0,
help="The maximum number of audio seconds in a batch."
"Determines batch size dynamically.",
)
parser.add_argument(
"--split",
type=int,
default=1,
help="Split the cut_set into multiple parts",
)
return parser.parse_args()
class PypinyinBackend:
"""PypinyinBackend for Chinese. Most codes is referenced from espnet.
There are two types pinyin or initials_finals, one is
just like "ni1 hao3", the other is like "n i1 h ao3".
"""
def __init__(
self,
backend="initials_finals",
punctuation_marks: Union[str, Pattern] = Punctuation.default_marks(),
) -> None:
self.backend = backend
self.punctuation_marks = punctuation_marks
def phonemize(
self, text: List[str], separator: Separator, strip=True, njobs=1
) -> List[str]:
assert isinstance(text, List)
phonemized = []
for _text in text:
_text = re.sub(" +", " ", _text.strip())
_text = _text.replace(" ", separator.word)
phones = []
if self.backend == "pypinyin":
for n, py in enumerate(
pinyin(
_text, style=Style.TONE3, neutral_tone_with_five=True
)
):
if all([c in self.punctuation_marks for c in py[0]]):
if len(phones):
assert phones[-1] == separator.syllable
phones.pop(-1)
phones.extend(list(py[0]))
else:
phones.extend([py[0], separator.syllable])
elif self.backend == "pypinyin_initials_finals":
for n, py in enumerate(
pinyin(
_text, style=Style.TONE3, neutral_tone_with_five=True
)
):
if all([c in self.punctuation_marks for c in py[0]]):
if len(phones):
assert phones[-1] == separator.syllable
phones.pop(-1)
phones.extend(list(py[0]))
else:
if py[0][-1].isalnum():
initial = get_initials(py[0], strict=False)
if py[0][-1].isdigit():
final = (
get_finals(py[0][:-1], strict=False)
+ py[0][-1]
)
else:
final = get_finals(py[0], strict=False)
phones.extend(
[
initial,
separator.phone,
final,
separator.syllable,
]
)
else:
                            raise ValueError(f"Unexpected pinyin output: {py[0]}")
else:
raise NotImplementedError
phonemized.append(
"".join(phones).rstrip(f"{separator.word}{separator.syllable}")
)
return phonemized
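# A rough illustration of the two modes (hedged; the exact strings depend on the
# installed pypinyin version). With the default separator used further below,
# Separator(word="_", syllable="-", phone="|"),
#   PypinyinBackend(backend="pypinyin").phonemize(["你好"], separator)
# yields something like ["ni3-hao3"], while backend="pypinyin_initials_finals"
# splits every syllable into initial and final, e.g. ["n|i3-h|ao3"].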
class TextTokenizer:
"""Phonemize Text."""
def __init__(
self,
language="en-us",
backend="espeak",
separator=Separator(word="_", syllable="-", phone="|"),
preserve_punctuation=True,
punctuation_marks: Union[str, Pattern] = Punctuation.default_marks(),
with_stress: bool = False,
tie: Union[bool, str] = False,
language_switch: LanguageSwitch = "keep-flags",
words_mismatch: WordMismatch = "ignore",
) -> None:
if backend == "espeak":
phonemizer = EspeakBackend(
language,
punctuation_marks=punctuation_marks,
preserve_punctuation=preserve_punctuation,
with_stress=with_stress,
tie=tie,
language_switch=language_switch,
words_mismatch=words_mismatch,
)
elif backend in ["pypinyin", "pypinyin_initials_finals"]:
phonemizer = PypinyinBackend(
backend=backend,
punctuation_marks=punctuation_marks + separator.word,
)
else:
raise NotImplementedError(f"{backend}")
self.backend = phonemizer
self.separator = separator
def to_list(self, phonemized: str) -> List[str]:
fields = []
for word in phonemized.split(self.separator.word):
# "ɐ m|iː|n?" ɹ|ɪ|z|ɜː|v; h|ɪ|z.
pp = re.findall(r"\w+|[^\w\s]", word, re.UNICODE)
fields.extend(
[p for p in pp if p != self.separator.phone]
+ [self.separator.word]
)
assert len("".join(fields[:-1])) == len(phonemized) - phonemized.count(
self.separator.phone
)
return fields[:-1]
def __call__(self, text, strip=True) -> List[List[str]]:
if isinstance(text, str):
text = [text]
phonemized = self.backend.phonemize(
text, separator=self.separator, strip=strip, njobs=1
)
return [self.to_list(p) for p in phonemized]
def tokenize_text(tokenizer: TextTokenizer, text: str) -> List[str]:
phonemes = tokenizer([text.strip()])
return phonemes[0] # k2symbols
def remove_encodec_weight_norm(model):
from encodec.modules import SConv1d
from encodec.modules.seanet import SConvTranspose1d, SEANetResnetBlock
from torch.nn.utils import remove_weight_norm
encoder = model.encoder.model
for key in encoder._modules:
if isinstance(encoder._modules[key], SEANetResnetBlock):
remove_weight_norm(encoder._modules[key].shortcut.conv.conv)
block_modules = encoder._modules[key].block._modules
for skey in block_modules:
if isinstance(block_modules[skey], SConv1d):
remove_weight_norm(block_modules[skey].conv.conv)
elif isinstance(encoder._modules[key], SConv1d):
remove_weight_norm(encoder._modules[key].conv.conv)
decoder = model.decoder.model
for key in decoder._modules:
if isinstance(decoder._modules[key], SEANetResnetBlock):
remove_weight_norm(decoder._modules[key].shortcut.conv.conv)
block_modules = decoder._modules[key].block._modules
for skey in block_modules:
if isinstance(block_modules[skey], SConv1d):
remove_weight_norm(block_modules[skey].conv.conv)
elif isinstance(decoder._modules[key], SConvTranspose1d):
remove_weight_norm(decoder._modules[key].convtr.convtr)
elif isinstance(decoder._modules[key], SConv1d):
remove_weight_norm(decoder._modules[key].conv.conv)
class AudioTokenizer:
"""EnCodec audio."""
def __init__(
self,
device: Any = None,
) -> None:
# Instantiate a pretrained EnCodec model
model = EncodecModel.encodec_model_24khz()
model.set_target_bandwidth(6.0)
remove_encodec_weight_norm(model)
if not device:
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda:0")
self._device = device
self.codec = model.to(device)
self.sample_rate = model.sample_rate
self.channels = model.channels
@property
def device(self):
return self._device
def encode(self, wav: torch.Tensor) -> torch.Tensor:
return self.codec.encode(wav.to(self.device))
def decode(self, frames: torch.Tensor) -> torch.Tensor:
return self.codec.decode(frames)
@dataclass
class AudioTokenConfig:
frame_shift: Seconds = 320.0 / 24000
num_quantizers: int = 8
def to_dict(self) -> Dict[str, Any]:
return asdict(self)
@staticmethod
def from_dict(data: Dict[str, Any]) -> "AudioTokenConfig":
return AudioTokenConfig(**data)
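# Sanity check on the numbers above: EnCodec at 24 kHz uses a hop of 320 samples,
# hence frame_shift = 320.0 / 24000 s, i.e. 24000 / 320 = 75 frames per second.
# At the 6 kbps bandwidth selected in AudioTokenizer, each frame carries
# 8 codebook indices of 10 bits each: 75 * 8 * 10 = 6000 bits/s, so num_quantizers = 8.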
class AudioTokenExtractor(FeatureExtractor):
name = "encodec"
config_type = AudioTokenConfig
def __init__(self, config: Optional[Any] = None):
super(AudioTokenExtractor, self).__init__(config)
self.tokenizer = AudioTokenizer()
def extract(
self, samples: Union[np.ndarray, torch.Tensor], sampling_rate: int
) -> np.ndarray:
if not isinstance(samples, torch.Tensor):
samples = torch.from_numpy(samples)
if sampling_rate != self.tokenizer.sample_rate:
samples = convert_audio(
samples,
sampling_rate,
self.tokenizer.sample_rate,
self.tokenizer.channels,
)
if len(samples.shape) == 2:
samples = samples.unsqueeze(0)
else:
raise ValueError()
device = self.tokenizer.device
encoded_frames = self.tokenizer.encode(samples.detach().to(device))
codes = encoded_frames[0][0] # [B, n_q, T]
if True:
duration = round(samples.shape[-1] / sampling_rate, ndigits=12)
expected_num_frames = compute_num_frames(
duration=duration,
frame_shift=self.frame_shift,
sampling_rate=sampling_rate,
)
assert abs(codes.shape[-1] - expected_num_frames) <= 1
codes = codes[..., :expected_num_frames]
return codes.cpu().squeeze(0).permute(1, 0).numpy()
@property
def frame_shift(self) -> Seconds:
return self.config.frame_shift
def feature_dim(self, sampling_rate: int) -> int:
return self.config.num_quantizers
def pad_tensor_list(self, tensor_list, device, padding_value=0):
        # Record the original length of each tensor
        lengths = [tensor.shape[0] for tensor in tensor_list]
        # Move to the target device and pad to the longest length with pad_sequence
tensor_list = [torch.Tensor(t).to(device) for t in tensor_list]
padded_tensor = torch.nn.utils.rnn.pad_sequence(
tensor_list, batch_first=True, padding_value=padding_value
)
return padded_tensor, lengths
def extract_batch(self, samples, sampling_rate, lengths) -> np.ndarray:
samples = [wav.squeeze() for wav in samples]
device = self.tokenizer.device
samples, lengths = self.pad_tensor_list(samples, device)
samples = samples.unsqueeze(1)
if not isinstance(samples, torch.Tensor):
samples = torch.from_numpy(samples)
if len(samples.shape) != 3:
raise ValueError()
if sampling_rate != self.tokenizer.sample_rate:
samples = [
convert_audio(
wav,
sampling_rate,
self.tokenizer.sample_rate,
self.tokenizer.channels,
)
for wav in samples
]
samples = torch.stack(samples, 0) # convert samples from list to tensor
# Extract discrete codes from EnCodec
with torch.no_grad():
encoded_frames = self.tokenizer.encode(samples.detach().to(device))
encoded_frames = encoded_frames[0][0] # [B, n_q, T]
batch_codes = []
for b, length in enumerate(lengths):
codes = encoded_frames[b]
duration = round(length / sampling_rate, ndigits=12)
expected_num_frames = compute_num_frames(
duration=duration,
frame_shift=self.frame_shift,
sampling_rate=sampling_rate,
)
batch_codes.append(codes[..., :expected_num_frames])
return [codes.cpu().permute(1, 0).numpy() for codes in batch_codes]
def main():
args = get_args()
dataset_parts = args.dataset_parts.replace("--dataset-parts", "").strip()
if dataset_parts == "all": # LibriTTS
dataset_parts = [
"dev-clean",
"dev-other",
"test-clean",
"test-other",
"train-clean-100",
"train-clean-360",
"train-other-500",
]
else:
dataset_parts = dataset_parts.replace("-p", "").strip().split(" ")
assert len(dataset_parts) >= 1
manifests = read_manifests_if_cached(
dataset_parts=dataset_parts,
output_dir=args.src_dir,
prefix=args.prefix,
suffix=args.suffix,
types=["recordings", "supervisions", "cuts"],
)
text_tokenizer = None
if args.text_extractor:
text_tokenizer = TextTokenizer(backend=args.text_extractor)
audio_extractor = None
if args.audio_extractor:
if args.audio_extractor == "Encodec":
audio_extractor = AudioTokenExtractor(AudioTokenConfig())
else:
assert args.audio_extractor == "Fbank"
audio_extractor = get_fbank_extractor()
Path(args.output_dir).mkdir(parents=True, exist_ok=True)
unique_symbols = set()
num_jobs = min(32, os.cpu_count())
logging.info(f"dataset_parts: {dataset_parts} manifests {len(manifests)}")
prefix = args.prefix
if prefix and not prefix.endswith("_"):
prefix = f"{prefix}_"
with get_executor() as ex:
for partition, m in manifests.items():
logging.info(
f"Processing partition: {partition} CUDA: {torch.cuda.is_available()}"
)
try:
cut_set = CutSet.from_manifests(
recordings=m["recordings"],
supervisions=m["supervisions"],
)
except Exception:
cut_set = m["cuts"]
# Split cut_set if split > 1
split = 1
if args.split > 1:
cut_sets = cut_set.split(args.split)
split = args.split
else:
cut_sets = [cut_set]
for idx, part in enumerate(cut_sets):
# AudioTokenizer
if args.audio_extractor:
if args.audio_extractor == "Encodec":
storage_path = (
f"{args.output_dir}/{args.prefix}_encodec_{partition}_{idx if split > 1 else ''}"
)
else:
storage_path = (
f"{args.output_dir}/{args.prefix}_fbank_{partition}_{idx if split > 1 else ''}"
)
if args.prefix.lower() in ["ljspeech", "aishell", "baker", "wenetspeech4tts"]:
part = part.resample(24000)
with torch.no_grad():
if (
torch.cuda.is_available()
and args.audio_extractor == "Encodec"
):
part = part.compute_and_store_features_batch(
extractor=audio_extractor,
storage_path=storage_path,
num_workers=num_jobs,
batch_duration=args.batch_duration,
collate=False,
overwrite=True,
storage_type=NumpyHdf5Writer,
)
else:
part = part.compute_and_store_features(
extractor=audio_extractor,
storage_path=storage_path,
num_jobs=num_jobs if ex is None else 64,
executor=ex,
storage_type=NumpyHdf5Writer,
)
# TextTokenizer
if args.text_extractor:
for c in tqdm(part):
if args.prefix == "baker" and args.text_extractor == "labeled_pinyin":
phonemes = c.supervisions[0].custom["tokens"]["text"]
unique_symbols.update(phonemes)
else:
if args.prefix == "ljspeech":
text = c.supervisions[0].custom["normalized_text"]
                                text = text.replace("”", '"').replace("“", '"')
phonemes = tokenize_text(text_tokenizer, text=text)
elif args.prefix in ["aishell", "aishell2", "wenetspeech4tts", "libritts"]:
phonemes = tokenize_text(
text_tokenizer, text=c.supervisions[0].text
)
if c.supervisions[0].custom is None:
c.supervisions[0].custom = {}
else:
raise NotImplementedError(f"{args.prefix}")
c.supervisions[0].custom["tokens"] = {"text": phonemes}
unique_symbols.update(phonemes)
# Save each part with an index if split > 1
cuts_filename = f"{prefix}cuts_{partition}.{idx if split > 1 else ''}.{args.suffix}"
part.to_file(f"{args.output_dir}/{cuts_filename}")
logging.info(f"Saved {cuts_filename}")
if args.text_extractor:
unique_phonemes = SymbolTable()
for s in sorted(list(unique_symbols)):
unique_phonemes.add(s)
logging.info(f"{len(unique_symbols)} unique phonemes: {unique_symbols}")
unique_phonemes_file = f"{args.output_dir}/unique_text_tokens.k2symbols"
unique_phonemes.to_file(unique_phonemes_file)
if __name__ == "__main__":
formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO)
main()
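A minimal usage sketch for the front ends defined above (illustrative only; it assumes the valle.data package from this recipe, espeak-ng, and a downloadable pretrained EnCodec checkpoint are available; "prompt.wav" is a placeholder file name and the exact phoneme symbols depend on the phonemizer backend):

import torchaudio

from valle.data import AudioTokenConfig, AudioTokenExtractor, TextTokenizer, tokenize_text

# Text front end: a flat list of phone symbols; "_" marks word boundaries
# with the default Separator(word="_", syllable="-", phone="|").
text_tokenizer = TextTokenizer(backend="espeak")
phonemes = tokenize_text(text_tokenizer, "To get up and running quickly.")

# Audio front end: 8 EnCodec codebook indices per ~13.3 ms frame, shape (T, 8).
wav, sr = torchaudio.load("prompt.wav")  # placeholder input file
codes = AudioTokenExtractor(AudioTokenConfig()).extract(wav, sampling_rate=sr)
print(len(phonemes), codes.shape)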

View File

@@ -0,0 +1,304 @@
#!/usr/bin/env python3
# Copyright 2023 (authors: Feiteng Li)
# Copyright 2024 (authors: Yuekai Zhang)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Phonemize Text and EnCodec Audio.
Usage example:
python3 bin/infer.py --output-dir demos_epoch_${epoch}_avg_${avg} \
--checkpoint=${exp_dir}/epoch-${epoch}-avg-${avg}.pt \
--text-prompts "KNOT one point one five miles per hour." \
--audio-prompts ./prompts/8463_294825_000043_000000.wav \
--text "To get up and running quickly just follow the steps below."
python3 bin/infer.py --output-dir demos_epoch_${epoch}_avg_${avg} \
--top-k -1 --temperature 1.0 \
--text-prompts "" \
--audio-prompts "" \
--text ./libritts.txt \
--checkpoint ${exp_dir}/epoch-${epoch}-avg-${avg}.pt
"""
import argparse
import logging
import os
from pathlib import Path
os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"
import torch
import torchaudio
from icefall.utils import AttributeDict, str2bool
from valle.data import (
AudioTokenizer,
TextTokenizer,
tokenize_audio,
tokenize_text,
)
from valle.data.collation import get_text_token_collater
from valle.models import get_model
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--text-prompts",
type=str,
default="",
help="Text prompts which are separated by |.",
)
parser.add_argument(
"--audio-prompts",
type=str,
default="",
help="Audio prompts which are separated by | and should be aligned with --text-prompts.",
)
parser.add_argument(
"--text",
type=str,
default="To get up and running quickly just follow the steps below.",
help="Text to be synthesized.",
)
# model
# add_model_arguments(parser)
# parser.add_argument(
# "--text-tokens",
# type=str,
# default="data/tokenized/unique_text_tokens.k2symbols",
# help="Path to the unique text tokens file.",
# )
parser.add_argument(
"--text-extractor",
type=str,
default="espeak",
help="espeak or pypinyin or pypinyin_initials_finals",
)
parser.add_argument(
"--checkpoint",
type=str,
default="exp/vallf_nano_full/checkpoint-100000.pt",
help="Path to the saved checkpoint.",
)
parser.add_argument(
"--output-dir",
type=Path,
default=Path("infer/demo"),
help="Path to the tokenized files.",
)
parser.add_argument(
"--top-k",
type=int,
default=-100,
help="Whether AR Decoder do top_k(if > 0) sampling.",
)
parser.add_argument(
"--top-p",
type=float,
default=1.0,
help="Whether AR Decoder do top_p(if > 0) sampling.",
)
parser.add_argument(
"--temperature",
type=float,
default=1.0,
help="The temperature of AR Decoder top_k sampling.",
)
parser.add_argument(
"--continual",
type=str2bool,
default=False,
help="Do continual task.",
)
return parser.parse_args()
def load_model(checkpoint, device):
if not checkpoint:
return None
checkpoint = torch.load(checkpoint, map_location=device)
args = AttributeDict(checkpoint)
model = get_model(args)
missing_keys, unexpected_keys = model.load_state_dict(
checkpoint["model"], strict=True
)
assert not missing_keys
model.to(device)
model.eval()
text_tokens = args.text_tokens
return model, text_tokens
@torch.no_grad()
def main():
args = get_args()
text_tokenizer = TextTokenizer(backend=args.text_extractor)
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
model, text_tokens = load_model(args.checkpoint, device)
text_collater = get_text_token_collater(text_tokens)
audio_tokenizer = AudioTokenizer()
Path(args.output_dir).mkdir(parents=True, exist_ok=True)
text_prompts = " ".join(args.text_prompts.split("|"))
audio_prompts = []
if args.audio_prompts:
for n, audio_file in enumerate(args.audio_prompts.split("|")):
encoded_frames = tokenize_audio(audio_tokenizer, audio_file)
if False:
samples = audio_tokenizer.decode(encoded_frames)
torchaudio.save(
f"{args.output_dir}/p{n}.wav", samples[0], 24000
)
audio_prompts.append(encoded_frames[0][0])
assert len(args.text_prompts.split("|")) == len(audio_prompts)
audio_prompts = torch.concat(audio_prompts, dim=-1).transpose(2, 1)
audio_prompts = audio_prompts.to(device)
if os.path.isfile(args.text): # for demos
# https://github.com/lifeiteng/lifeiteng.github.com/blob/main/valle/prepare.py
with open(args.text) as f:
for line in f:
# fields = line.strip().split("\t")
fields = line.strip().split(" ")
fields = [item for item in fields if item]
assert len(fields) == 4
prompt_text, prompt_audio, text, audio_path = fields
logging.info(f"synthesize text: {text}")
text_tokens, text_tokens_lens = text_collater(
[
tokenize_text(
text_tokenizer, text=f"{prompt_text} {text}".strip()
)
]
)
_, enroll_x_lens = text_collater(
[
tokenize_text(
text_tokenizer, text=f"{prompt_text}".strip()
)
]
)
audio_prompts = tokenize_audio(audio_tokenizer, prompt_audio)
audio_prompts = audio_prompts[0][0].transpose(2, 1).to(device)
# synthesis
encoded_frames = model.inference(
text_tokens.to(device),
text_tokens_lens.to(device),
audio_prompts,
enroll_x_lens=enroll_x_lens,
top_k=args.top_k,
temperature=args.temperature,
top_p=args.top_p,
)
samples = audio_tokenizer.decode(
[(encoded_frames.transpose(2, 1), None)]
)
# store
# save audio path into args.output_dir + audio_path
audio_path = f"{args.output_dir}/{audio_path}"
# mkdir -p
os.makedirs(os.path.dirname(audio_path), exist_ok=True)
torchaudio.save(audio_path, samples[0].cpu(), 24000)
return
for n, text in enumerate(args.text.split("|")):
logging.info(f"synthesize text: {text}")
text_tokens, text_tokens_lens = text_collater(
[
tokenize_text(
text_tokenizer, text=f"{text_prompts} {text}".strip()
)
]
)
# synthesis
if args.continual:
assert text == ""
encoded_frames = model.continual(
text_tokens.to(device),
text_tokens_lens.to(device),
audio_prompts,
)
else:
enroll_x_lens = None
if text_prompts:
_, enroll_x_lens = text_collater(
[
tokenize_text(
text_tokenizer, text=f"{text_prompts}".strip()
)
]
)
encoded_frames = model.inference(
text_tokens.to(device),
text_tokens_lens.to(device),
audio_prompts,
enroll_x_lens=enroll_x_lens,
top_k=args.top_k,
temperature=args.temperature,
top_p=args.top_p,
)
if audio_prompts != []:
samples = audio_tokenizer.decode(
[(encoded_frames.transpose(2, 1), None)]
)
# store
torchaudio.save(
f"{args.output_dir}/{n}.wav", samples[0].cpu(), 24000
)
else: # Transformer
pass
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
torch._C._jit_set_profiling_executor(False)
torch._C._jit_set_profiling_mode(False)
torch._C._set_graph_executor_optimize(False)
if __name__ == "__main__":
formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO)
main()
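When --text points to a file (the demo path above), each line must reduce to exactly four whitespace-separated fields, unpacked as prompt_text, prompt_audio, text, audio_path; with the split(" ") parser the two text fields therefore cannot contain bare spaces (multi-word prompts would need the tab-separated parsing that is commented out). A hedged sketch of writing such a file, with made-up field values:

# Illustrative only: build demo.txt for `infer.py --text demo.txt`.
rows = [
    ("Hello.", "./prompts/8463_294825_000043_000000.wav", "Goodbye.", "demo_0.wav"),
]
with open("demo.txt", "w") as f:
    for prompt_text, prompt_audio, text, audio_path in rows:
        f.write(f"{prompt_text} {prompt_audio} {text} {audio_path}\n")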

View File

@@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/optim.py

View File

@@ -0,0 +1,121 @@
from pathlib import Path
from typing import List, Tuple
import numpy as np
import torch
from k2 import SymbolTable
class TextTokenCollater:
"""Collate list of text tokens
Map sentences to integers. Sentences are padded to equal length.
Beginning and end-of-sequence symbols can be added.
Example:
>>> token_collater = TextTokenCollater(text_tokens)
>>> tokens_batch, tokens_lens = token_collater(text)
Returns:
tokens_batch: IntTensor of shape (B, L)
B: batch dimension, number of input sentences
L: length of the longest sentence
tokens_lens: IntTensor of shape (B,)
Length of each sentence after adding <eos> and <bos>
but before padding.
"""
def __init__(
self,
text_tokens: List[str],
add_eos: bool = True,
add_bos: bool = True,
pad_symbol: str = "<pad>",
bos_symbol: str = "<bos>",
eos_symbol: str = "<eos>",
):
self.pad_symbol = pad_symbol
self.add_eos = add_eos
self.add_bos = add_bos
self.bos_symbol = bos_symbol
self.eos_symbol = eos_symbol
unique_tokens = (
[pad_symbol]
+ ([bos_symbol] if add_bos else [])
+ ([eos_symbol] if add_eos else [])
+ sorted(text_tokens)
)
self.token2idx = {token: idx for idx, token in enumerate(unique_tokens)}
self.idx2token = [token for token in unique_tokens]
def index(
self, tokens_list: List[str]
) -> Tuple[torch.Tensor, torch.Tensor]:
seqs, seq_lens = [], []
for tokens in tokens_list:
            assert all(s in self.token2idx for s in tokens)
seq = (
([self.bos_symbol] if self.add_bos else [])
+ list(tokens)
+ ([self.eos_symbol] if self.add_eos else [])
)
seqs.append(seq)
seq_lens.append(len(seq))
max_len = max(seq_lens)
for k, (seq, seq_len) in enumerate(zip(seqs, seq_lens)):
seq.extend([self.pad_symbol] * (max_len - seq_len))
tokens = torch.from_numpy(
np.array(
[[self.token2idx[token] for token in seq] for seq in seqs],
dtype=np.int64,
)
)
tokens_lens = torch.IntTensor(seq_lens)
return tokens, tokens_lens
def __call__(self, texts: List[str]) -> Tuple[torch.Tensor, torch.Tensor]:
tokens_seqs = [[p for p in text] for text in texts]
max_len = len(max(tokens_seqs, key=len))
seqs = [
([self.bos_symbol] if self.add_bos else [])
+ list(seq)
+ ([self.eos_symbol] if self.add_eos else [])
+ [self.pad_symbol] * (max_len - len(seq))
for seq in tokens_seqs
]
tokens_batch = torch.from_numpy(
np.array(
[[self.token2idx[token] for token in seq] for seq in seqs],
dtype=np.int64,
)
)
tokens_lens = torch.IntTensor(
[
len(seq) + int(self.add_eos) + int(self.add_bos)
for seq in tokens_seqs
]
)
return tokens_batch, tokens_lens
def get_text_token_collater(text_tokens_file: str) -> TextTokenCollater:
text_tokens_path = Path(text_tokens_file)
unique_tokens = SymbolTable.from_file(text_tokens_path)
collater = TextTokenCollater(
unique_tokens.symbols, add_bos=True, add_eos=True
)
return collater
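A small worked example of the collater above (illustrative; the three-symbol token inventory is made up):

collater = TextTokenCollater(["a", "b", "c"], add_bos=True, add_eos=True)
batch, lens = collater([["a", "b"], ["c"]])
# batch has shape (2, 4): <bos> + tokens + <eos>, right-padded with <pad> (index 0):
#   [[1, 3, 4, 2],
#    [1, 5, 2, 0]]
# lens counts <bos>/<eos> but not the padding: tensor([4, 3], dtype=torch.int32)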

egs/libritts/TTS/valle/train.py (executable file, 1287 additions)

File diff suppressed because it is too large.

View File

@@ -0,0 +1,344 @@
# Copyright 2021 Piotr Żelasko
# Copyright 2022-2024 Xiaomi Corporation (Authors: Mingshuang Luo,
# Zengwei Yao,
# Zengrui Jin,)
# Copyright 2023 (authors: Feiteng Li)
# Copyright 2024 (Author: Yuekai Zhang)
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
from functools import lru_cache
from pathlib import Path
from typing import Any, Dict, Optional
import torch
from lhotse import CutSet, Spectrogram, SpectrogramConfig, load_manifest_lazy
from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures
CutConcatenate,
DynamicBucketingSampler,
PrecomputedFeatures,
SimpleCutSampler,
SpeechSynthesisDataset,
)
from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples
AudioSamples,
OnTheFlyFeatures,
)
from lhotse.features.io import KaldiReader
from lhotse.utils import fix_random_seed
from torch.utils.data import DataLoader
from icefall.utils import str2bool
class _SeedWorkers:
def __init__(self, seed: int):
self.seed = seed
def __call__(self, worker_id: int):
fix_random_seed(self.seed + worker_id)
class TtsDataModule:
"""
    DataModule for TTS experiments.
    It assumes there is always one train and valid dataloader,
    but there can be multiple test dataloaders (e.g. LibriTTS test-clean
    and test-other).
    It contains all the common data pipeline modules used in TTS
    experiments, e.g.:
    - dynamic batch size,
    - bucketing samplers,
    - cut concatenation,
    - on-the-fly feature extraction.
    This class should be derived for specific corpora used in TTS tasks.
"""
def __init__(self, args: argparse.Namespace):
self.args = args
@classmethod
def add_arguments(cls, parser: argparse.ArgumentParser):
group = parser.add_argument_group(
title="TTS data related options",
description="These options are used for the preparation of "
"PyTorch DataLoaders from Lhotse CutSet's -- they control the "
"effective batch sizes, sampling strategies, applied data "
"augmentations, etc.",
)
group.add_argument(
"--manifest-dir",
type=Path,
default=Path("data/tokenized"),
help="Path to directory with train/valid/test cuts.",
)
group.add_argument(
"--speaker-embeds",
type=Path,
default=Path("exp/xvector_nnet_1a/"),
help="Path to directory with speaker embeddings.",
)
group.add_argument(
"--max-duration",
type=int,
default=200.0,
help="Maximum pooled recordings duration (seconds) in a "
"single batch. You can reduce it if it causes CUDA OOM.",
)
group.add_argument(
"--bucketing-sampler",
type=str2bool,
default=True,
help="When enabled, the batches will come from buckets of "
"similar duration (saves padding frames).",
)
group.add_argument(
"--num-buckets",
type=int,
default=30,
help="The number of buckets for the DynamicBucketingSampler"
"(you might want to increase it for larger datasets).",
)
group.add_argument(
"--on-the-fly-feats",
type=str2bool,
default=False,
help="When enabled, use on-the-fly cut mixing and feature "
"extraction. Will drop existing precomputed feature manifests "
"if available.",
)
group.add_argument(
"--shuffle",
type=str2bool,
default=True,
help="When enabled (=default), the examples will be "
"shuffled for each epoch.",
)
group.add_argument(
"--drop-last",
type=str2bool,
default=True,
help="Whether to drop last batch. Used by sampler.",
)
group.add_argument(
"--return-cuts",
type=str2bool,
default=True,
help="When enabled, each batch will have the "
"field: batch['cut'] with the cuts that "
"were used to construct it.",
)
group.add_argument(
"--num-workers",
type=int,
default=4,
help="The number of training dataloader workers that "
"collect the batches.",
)
group.add_argument(
"--enable-spec-aug",
type=str2bool,
default=False,
help="When enabled, use SpecAugment for training dataset.",
)
group.add_argument(
"--input-strategy",
type=str,
default="PrecomputedFeatures",
help="AudioSamples or PrecomputedFeatures",
)
group.add_argument(
"--dataset",
type=str,
default="libritts",
help="--input-strategy PromptedPrecomputedFeatures needs dataset name to prepare prompts.",
)
parser.add_argument(
"--sampling-rate",
type=int,
default=24000,
help="""Audio sampling rate.""",
)
def train_dataloaders(
self,
cuts_train: CutSet,
sampler_state_dict: Optional[Dict[str, Any]] = None,
) -> DataLoader:
"""
Args:
cuts_train:
CutSet for training.
sampler_state_dict:
The state dict for the training sampler.
"""
logging.info("About to create train dataset")
train = SpeechSynthesisDataset(
return_text=False,
return_tokens=False,
return_spk_ids=False,
feature_input_strategy=eval(self.args.input_strategy)(),
return_cuts=self.args.return_cuts,
)
if self.args.on_the_fly_feats:
raise NotImplementedError
if self.args.bucketing_sampler:
logging.info("Using DynamicBucketingSampler.")
train_sampler = DynamicBucketingSampler(
cuts_train,
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
num_buckets=self.args.num_buckets,
buffer_size=self.args.num_buckets * 2000,
shuffle_buffer_size=self.args.num_buckets * 5000,
drop_last=self.args.drop_last,
)
else:
logging.info("Using SimpleCutSampler.")
train_sampler = SimpleCutSampler(
cuts_train,
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
)
logging.info("About to create train dataloader")
if sampler_state_dict is not None:
logging.info("Loading sampler state dict")
train_sampler.load_state_dict(sampler_state_dict)
# 'seed' is derived from the current random state, which will have
# previously been set in the main process.
seed = torch.randint(0, 100000, ()).item()
worker_init_fn = _SeedWorkers(seed)
train_dl = DataLoader(
train,
sampler=train_sampler,
batch_size=None,
num_workers=self.args.num_workers,
persistent_workers=False,
worker_init_fn=worker_init_fn,
)
return train_dl
def dev_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
logging.info("About to create dev dataset")
if self.args.on_the_fly_feats:
raise NotImplementedError
else:
validate = SpeechSynthesisDataset(
return_text=False,
return_tokens=False,
return_spk_ids=False,
feature_input_strategy=eval(self.args.input_strategy)(),
return_cuts=self.args.return_cuts,
)
dev_sampler = DynamicBucketingSampler(
cuts_valid,
max_duration=self.args.max_duration,
shuffle=False,
)
logging.info("About to create valid dataloader")
dev_dl = DataLoader(
validate,
sampler=dev_sampler,
batch_size=None,
num_workers=self.args.num_workers,
persistent_workers=False,
)
return dev_dl
def test_dataloaders(self, cuts: CutSet) -> DataLoader:
logging.info("About to create test dataset")
if self.args.on_the_fly_feats:
raise NotImplementedError
else:
test = SpeechSynthesisDataset(
return_text=False,
return_tokens=False,
return_spk_ids=False,
feature_input_strategy=eval(self.args.input_strategy)(),
return_cuts=self.args.return_cuts,
)
test_sampler = DynamicBucketingSampler(
cuts,
max_duration=self.args.max_duration,
shuffle=False,
)
logging.info("About to create test dataloader")
test_dl = DataLoader(
test,
batch_size=None,
sampler=test_sampler,
num_workers=self.args.num_workers,
)
return test_dl
@lru_cache()
def train_cuts(self) -> CutSet:
logging.info("About to get train cuts")
return load_manifest_lazy(
self.args.manifest_dir / "cuts_train.jsonl.gz"
)
@lru_cache()
def dev_cuts(self) -> CutSet:
logging.info("About to get dev cuts")
return load_manifest_lazy(self.args.manifest_dir / "cuts_dev.jsonl.gz")
@lru_cache()
def test_cuts(self) -> CutSet:
logging.info("About to get test cuts")
return load_manifest_lazy(self.args.manifest_dir / "cuts_test.jsonl.gz")
@lru_cache()
def dev_clean_cuts(self) -> CutSet:
logging.info("About to get dev-clean cuts")
return load_manifest_lazy(
self.args.manifest_dir / "libritts_cuts_dev-clean.jsonl.gz"
)
@lru_cache()
def dev_other_cuts(self) -> CutSet:
logging.info("About to get dev-other cuts")
return load_manifest_lazy(
self.args.manifest_dir / "libritts_cuts_dev-other.jsonl.gz"
)
@lru_cache()
def test_clean_cuts(self) -> CutSet:
logging.info("About to get test-clean cuts")
return load_manifest_lazy(
self.args.manifest_dir / "libritts_cuts_test-clean.jsonl.gz"
)
@lru_cache()
def test_other_cuts(self) -> CutSet:
logging.info("About to get test-other cuts")
return load_manifest_lazy(
self.args.manifest_dir / "libritts_cuts_test-other.jsonl.gz"
)
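A minimal sketch of how a training script typically drives this data module (illustrative only; it assumes this module has been imported and that the tokenization step above wrote its manifests to data/tokenized):

import argparse

parser = argparse.ArgumentParser()
TtsDataModule.add_arguments(parser)
args = parser.parse_args(["--manifest-dir", "data/tokenized", "--max-duration", "160"])

datamodule = TtsDataModule(args)
train_dl = datamodule.train_dataloaders(datamodule.train_cuts())  # data/tokenized/cuts_train.jsonl.gz
valid_dl = datamodule.dev_dataloaders(datamodule.dev_cuts())      # data/tokenized/cuts_dev.jsonl.gz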

File diff suppressed because it is too large.