icefall (mirror of https://github.com/k2-fsa/icefall.git)
Commit d55a534af8 (parent 5361ecdc56): add infer code
@ -1,575 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2023 (authors: Feiteng Li)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Phonemize Text and EnCodec Audio.
|
||||
|
||||
Usage example:
|
||||
python3 bin/tokenizer.py \
|
||||
--src_dir ./data/manifests --output_dir ./data/tokenized
|
||||
|
||||
"""
|
||||
import argparse
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
import torch.multiprocessing
|
||||
from icefall.utils import get_executor
|
||||
from lhotse import CutSet, NumpyHdf5Writer
|
||||
from lhotse.recipes.utils import read_manifests_if_cached
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
from valle.data import (
|
||||
AudioTokenConfig,
|
||||
AudioTokenExtractor,
|
||||
TextTokenizer,
|
||||
tokenize_text,
|
||||
)
|
||||
# from valle.data.fbank import get_fbank_extractor
|
||||
from valle.utils import SymbolTable
|
||||
|
||||
os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"
|
||||
|
||||
|
||||
# Torch's multithreaded behavior needs to be disabled or
|
||||
# it wastes a lot of CPU and slow things down.
|
||||
# Do this outside of main() in case it needs to take effect
|
||||
# even when we are not invoking the main (e.g. when spawning subprocesses).
|
||||
torch.set_num_threads(1)
|
||||
torch.set_num_interop_threads(1)
|
||||
torch.multiprocessing.set_sharing_strategy("file_system")
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument(
|
||||
"--src-dir",
|
||||
type=Path,
|
||||
default=Path("data/manifests"),
|
||||
help="Path to the manifest files",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output-dir",
|
||||
type=Path,
|
||||
default=Path("data/tokenized"),
|
||||
help="Path to the tokenized files",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--text-extractor",
|
||||
type=str,
|
||||
default="espeak",
|
||||
help="espeak or pypinyin or pypinyin_initials_finals",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--audio-extractor",
|
||||
type=str,
|
||||
default="Encodec",
|
||||
help="Encodec or Fbank",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dataset-parts",
|
||||
type=str,
|
||||
default="dev-clean test-clean",
|
||||
help="Space separated dataset parts",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--prefix",
|
||||
type=str,
|
||||
default="libritts",
|
||||
help="prefix of the manifest file",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--suffix",
|
||||
type=str,
|
||||
default="jsonl.gz",
|
||||
help="suffix of the manifest file",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--batch-duration",
|
||||
type=float,
|
||||
default=400.0,
|
||||
help="The maximum number of audio seconds in a batch."
|
||||
"Determines batch size dynamically.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--split",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Split the cut_set into multiple parts",
|
||||
)
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
class PypinyinBackend:
|
||||
"""PypinyinBackend for Chinese. Most codes is referenced from espnet.
|
||||
There are two types pinyin or initials_finals, one is
|
||||
just like "ni1 hao3", the other is like "n i1 h ao3".
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
backend="initials_finals",
|
||||
punctuation_marks: Union[str, Pattern] = Punctuation.default_marks(),
|
||||
) -> None:
|
||||
self.backend = backend
|
||||
self.punctuation_marks = punctuation_marks
|
||||
|
||||
def phonemize(
|
||||
self, text: List[str], separator: Separator, strip=True, njobs=1
|
||||
) -> List[str]:
|
||||
assert isinstance(text, List)
|
||||
phonemized = []
|
||||
for _text in text:
|
||||
_text = re.sub(" +", " ", _text.strip())
|
||||
_text = _text.replace(" ", separator.word)
|
||||
phones = []
|
||||
if self.backend == "pypinyin":
|
||||
for n, py in enumerate(
|
||||
pinyin(
|
||||
_text, style=Style.TONE3, neutral_tone_with_five=True
|
||||
)
|
||||
):
|
||||
if all([c in self.punctuation_marks for c in py[0]]):
|
||||
if len(phones):
|
||||
assert phones[-1] == separator.syllable
|
||||
phones.pop(-1)
|
||||
|
||||
phones.extend(list(py[0]))
|
||||
else:
|
||||
phones.extend([py[0], separator.syllable])
|
||||
elif self.backend == "pypinyin_initials_finals":
|
||||
for n, py in enumerate(
|
||||
pinyin(
|
||||
_text, style=Style.TONE3, neutral_tone_with_five=True
|
||||
)
|
||||
):
|
||||
if all([c in self.punctuation_marks for c in py[0]]):
|
||||
if len(phones):
|
||||
assert phones[-1] == separator.syllable
|
||||
phones.pop(-1)
|
||||
phones.extend(list(py[0]))
|
||||
else:
|
||||
if py[0][-1].isalnum():
|
||||
initial = get_initials(py[0], strict=False)
|
||||
if py[0][-1].isdigit():
|
||||
final = (
|
||||
get_finals(py[0][:-1], strict=False)
|
||||
+ py[0][-1]
|
||||
)
|
||||
else:
|
||||
final = get_finals(py[0], strict=False)
|
||||
phones.extend(
|
||||
[
|
||||
initial,
|
||||
separator.phone,
|
||||
final,
|
||||
separator.syllable,
|
||||
]
|
||||
)
|
||||
else:
|
||||
assert ValueError
|
||||
else:
|
||||
raise NotImplementedError
|
||||
phonemized.append(
|
||||
"".join(phones).rstrip(f"{separator.word}{separator.syllable}")
|
||||
)
|
||||
return phonemized
|
||||
|
||||
|
||||
class TextTokenizer:
|
||||
"""Phonemize Text."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
language="en-us",
|
||||
backend="espeak",
|
||||
separator=Separator(word="_", syllable="-", phone="|"),
|
||||
preserve_punctuation=True,
|
||||
punctuation_marks: Union[str, Pattern] = Punctuation.default_marks(),
|
||||
with_stress: bool = False,
|
||||
tie: Union[bool, str] = False,
|
||||
language_switch: LanguageSwitch = "keep-flags",
|
||||
words_mismatch: WordMismatch = "ignore",
|
||||
) -> None:
|
||||
if backend == "espeak":
|
||||
phonemizer = EspeakBackend(
|
||||
language,
|
||||
punctuation_marks=punctuation_marks,
|
||||
preserve_punctuation=preserve_punctuation,
|
||||
with_stress=with_stress,
|
||||
tie=tie,
|
||||
language_switch=language_switch,
|
||||
words_mismatch=words_mismatch,
|
||||
)
|
||||
elif backend in ["pypinyin", "pypinyin_initials_finals"]:
|
||||
phonemizer = PypinyinBackend(
|
||||
backend=backend,
|
||||
punctuation_marks=punctuation_marks + separator.word,
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(f"{backend}")
|
||||
|
||||
self.backend = phonemizer
|
||||
self.separator = separator
|
||||
|
||||
def to_list(self, phonemized: str) -> List[str]:
|
||||
fields = []
|
||||
for word in phonemized.split(self.separator.word):
|
||||
# "ɐ m|iː|n?" ɹ|ɪ|z|ɜː|v; h|ɪ|z.
|
||||
pp = re.findall(r"\w+|[^\w\s]", word, re.UNICODE)
|
||||
fields.extend(
|
||||
[p for p in pp if p != self.separator.phone]
|
||||
+ [self.separator.word]
|
||||
)
|
||||
assert len("".join(fields[:-1])) == len(phonemized) - phonemized.count(
|
||||
self.separator.phone
|
||||
)
|
||||
return fields[:-1]
|
||||
|
||||
def __call__(self, text, strip=True) -> List[List[str]]:
|
||||
if isinstance(text, str):
|
||||
text = [text]
|
||||
|
||||
phonemized = self.backend.phonemize(
|
||||
text, separator=self.separator, strip=strip, njobs=1
|
||||
)
|
||||
return [self.to_list(p) for p in phonemized]
|
||||
|
||||
|
||||
def tokenize_text(tokenizer: TextTokenizer, text: str) -> List[str]:
|
||||
phonemes = tokenizer([text.strip()])
|
||||
return phonemes[0] # k2symbols
|
||||
|
||||
|
||||
def remove_encodec_weight_norm(model):
|
||||
from encodec.modules import SConv1d
|
||||
from encodec.modules.seanet import SConvTranspose1d, SEANetResnetBlock
|
||||
from torch.nn.utils import remove_weight_norm
|
||||
|
||||
encoder = model.encoder.model
|
||||
for key in encoder._modules:
|
||||
if isinstance(encoder._modules[key], SEANetResnetBlock):
|
||||
remove_weight_norm(encoder._modules[key].shortcut.conv.conv)
|
||||
block_modules = encoder._modules[key].block._modules
|
||||
for skey in block_modules:
|
||||
if isinstance(block_modules[skey], SConv1d):
|
||||
remove_weight_norm(block_modules[skey].conv.conv)
|
||||
elif isinstance(encoder._modules[key], SConv1d):
|
||||
remove_weight_norm(encoder._modules[key].conv.conv)
|
||||
|
||||
decoder = model.decoder.model
|
||||
for key in decoder._modules:
|
||||
if isinstance(decoder._modules[key], SEANetResnetBlock):
|
||||
remove_weight_norm(decoder._modules[key].shortcut.conv.conv)
|
||||
block_modules = decoder._modules[key].block._modules
|
||||
for skey in block_modules:
|
||||
if isinstance(block_modules[skey], SConv1d):
|
||||
remove_weight_norm(block_modules[skey].conv.conv)
|
||||
elif isinstance(decoder._modules[key], SConvTranspose1d):
|
||||
remove_weight_norm(decoder._modules[key].convtr.convtr)
|
||||
elif isinstance(decoder._modules[key], SConv1d):
|
||||
remove_weight_norm(decoder._modules[key].conv.conv)
|
||||
|
||||
|
||||
class AudioTokenizer:
|
||||
"""EnCodec audio."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
device: Any = None,
|
||||
) -> None:
|
||||
# Instantiate a pretrained EnCodec model
|
||||
model = EncodecModel.encodec_model_24khz()
|
||||
model.set_target_bandwidth(6.0)
|
||||
remove_encodec_weight_norm(model)
|
||||
|
||||
if not device:
|
||||
device = torch.device("cpu")
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda:0")
|
||||
|
||||
self._device = device
|
||||
|
||||
self.codec = model.to(device)
|
||||
self.sample_rate = model.sample_rate
|
||||
self.channels = model.channels
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
return self._device
|
||||
|
||||
def encode(self, wav: torch.Tensor) -> torch.Tensor:
|
||||
return self.codec.encode(wav.to(self.device))
|
||||
|
||||
def decode(self, frames: torch.Tensor) -> torch.Tensor:
|
||||
return self.codec.decode(frames)
|
||||
|
||||
|
||||
@dataclass
|
||||
class AudioTokenConfig:
|
||||
frame_shift: Seconds = 320.0 / 24000
|
||||
num_quantizers: int = 8
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return asdict(self)
|
||||
|
||||
@staticmethod
|
||||
def from_dict(data: Dict[str, Any]) -> "AudioTokenConfig":
|
||||
return AudioTokenConfig(**data)
|
||||
|
||||
class AudioTokenExtractor(FeatureExtractor):
|
||||
name = "encodec"
|
||||
config_type = AudioTokenConfig
|
||||
|
||||
def __init__(self, config: Optional[Any] = None):
|
||||
super(AudioTokenExtractor, self).__init__(config)
|
||||
self.tokenizer = AudioTokenizer()
|
||||
|
||||
def extract(
|
||||
self, samples: Union[np.ndarray, torch.Tensor], sampling_rate: int
|
||||
) -> np.ndarray:
|
||||
if not isinstance(samples, torch.Tensor):
|
||||
samples = torch.from_numpy(samples)
|
||||
if sampling_rate != self.tokenizer.sample_rate:
|
||||
samples = convert_audio(
|
||||
samples,
|
||||
sampling_rate,
|
||||
self.tokenizer.sample_rate,
|
||||
self.tokenizer.channels,
|
||||
)
|
||||
if len(samples.shape) == 2:
|
||||
samples = samples.unsqueeze(0)
|
||||
else:
|
||||
raise ValueError()
|
||||
|
||||
device = self.tokenizer.device
|
||||
encoded_frames = self.tokenizer.encode(samples.detach().to(device))
|
||||
codes = encoded_frames[0][0] # [B, n_q, T]
|
||||
if True:
|
||||
duration = round(samples.shape[-1] / sampling_rate, ndigits=12)
|
||||
expected_num_frames = compute_num_frames(
|
||||
duration=duration,
|
||||
frame_shift=self.frame_shift,
|
||||
sampling_rate=sampling_rate,
|
||||
)
|
||||
assert abs(codes.shape[-1] - expected_num_frames) <= 1
|
||||
codes = codes[..., :expected_num_frames]
|
||||
return codes.cpu().squeeze(0).permute(1, 0).numpy()
|
||||
|
||||
@property
|
||||
def frame_shift(self) -> Seconds:
|
||||
return self.config.frame_shift
|
||||
|
||||
def feature_dim(self, sampling_rate: int) -> int:
|
||||
return self.config.num_quantizers
|
||||
|
||||
def pad_tensor_list(self, tensor_list, device, padding_value=0):
|
||||
# 计算每个张量的长度
|
||||
lengths = [tensor.shape[0] for tensor in tensor_list]
|
||||
# 使用pad_sequence函数进行填充
|
||||
tensor_list = [torch.Tensor(t).to(device) for t in tensor_list]
|
||||
padded_tensor = torch.nn.utils.rnn.pad_sequence(
|
||||
tensor_list, batch_first=True, padding_value=padding_value
|
||||
)
|
||||
return padded_tensor, lengths
|
||||
|
||||
def extract_batch(self, samples, sampling_rate, lengths) -> np.ndarray:
|
||||
samples = [wav.squeeze() for wav in samples]
|
||||
device = self.tokenizer.device
|
||||
samples, lengths = self.pad_tensor_list(samples, device)
|
||||
samples = samples.unsqueeze(1)
|
||||
|
||||
if not isinstance(samples, torch.Tensor):
|
||||
samples = torch.from_numpy(samples)
|
||||
if len(samples.shape) != 3:
|
||||
raise ValueError()
|
||||
if sampling_rate != self.tokenizer.sample_rate:
|
||||
samples = [
|
||||
convert_audio(
|
||||
wav,
|
||||
sampling_rate,
|
||||
self.tokenizer.sample_rate,
|
||||
self.tokenizer.channels,
|
||||
)
|
||||
for wav in samples
|
||||
]
|
||||
samples = torch.stack(samples, 0) # convert samples from list to tensor
|
||||
# Extract discrete codes from EnCodec
|
||||
with torch.no_grad():
|
||||
encoded_frames = self.tokenizer.encode(samples.detach().to(device))
|
||||
encoded_frames = encoded_frames[0][0] # [B, n_q, T]
|
||||
batch_codes = []
|
||||
for b, length in enumerate(lengths):
|
||||
codes = encoded_frames[b]
|
||||
duration = round(length / sampling_rate, ndigits=12)
|
||||
expected_num_frames = compute_num_frames(
|
||||
duration=duration,
|
||||
frame_shift=self.frame_shift,
|
||||
sampling_rate=sampling_rate,
|
||||
)
|
||||
batch_codes.append(codes[..., :expected_num_frames])
|
||||
return [codes.cpu().permute(1, 0).numpy() for codes in batch_codes]
|
||||
|
||||
def main():
|
||||
args = get_args()
|
||||
|
||||
dataset_parts = args.dataset_parts.replace("--dataset-parts", "").strip()
|
||||
if dataset_parts == "all": # LibriTTS
|
||||
dataset_parts = [
|
||||
"dev-clean",
|
||||
"dev-other",
|
||||
"test-clean",
|
||||
"test-other",
|
||||
"train-clean-100",
|
||||
"train-clean-360",
|
||||
"train-other-500",
|
||||
]
|
||||
else:
|
||||
dataset_parts = dataset_parts.replace("-p", "").strip().split(" ")
|
||||
|
||||
assert len(dataset_parts) >= 1
|
||||
|
||||
manifests = read_manifests_if_cached(
|
||||
dataset_parts=dataset_parts,
|
||||
output_dir=args.src_dir,
|
||||
prefix=args.prefix,
|
||||
suffix=args.suffix,
|
||||
types=["recordings", "supervisions", "cuts"],
|
||||
)
|
||||
|
||||
text_tokenizer = None
|
||||
if args.text_extractor:
|
||||
text_tokenizer = TextTokenizer(backend=args.text_extractor)
|
||||
|
||||
audio_extractor = None
|
||||
if args.audio_extractor:
|
||||
if args.audio_extractor == "Encodec":
|
||||
audio_extractor = AudioTokenExtractor(AudioTokenConfig())
|
||||
else:
|
||||
assert args.audio_extractor == "Fbank"
|
||||
audio_extractor = get_fbank_extractor()
|
||||
|
||||
Path(args.output_dir).mkdir(parents=True, exist_ok=True)
|
||||
unique_symbols = set()
|
||||
num_jobs = min(32, os.cpu_count())
|
||||
logging.info(f"dataset_parts: {dataset_parts} manifests {len(manifests)}")
|
||||
|
||||
prefix = args.prefix
|
||||
if prefix and not prefix.endswith("_"):
|
||||
prefix = f"{prefix}_"
|
||||
with get_executor() as ex:
|
||||
for partition, m in manifests.items():
|
||||
logging.info(
|
||||
f"Processing partition: {partition} CUDA: {torch.cuda.is_available()}"
|
||||
)
|
||||
try:
|
||||
cut_set = CutSet.from_manifests(
|
||||
recordings=m["recordings"],
|
||||
supervisions=m["supervisions"],
|
||||
)
|
||||
except Exception:
|
||||
cut_set = m["cuts"]
|
||||
|
||||
# Split cut_set if split > 1
|
||||
split = 1
|
||||
if args.split > 1:
|
||||
cut_sets = cut_set.split(args.split)
|
||||
split = args.split
|
||||
else:
|
||||
cut_sets = [cut_set]
|
||||
|
||||
for idx, part in enumerate(cut_sets):
|
||||
# AudioTokenizer
|
||||
if args.audio_extractor:
|
||||
if args.audio_extractor == "Encodec":
|
||||
storage_path = (
|
||||
f"{args.output_dir}/{args.prefix}_encodec_{partition}_{idx if split > 1 else ''}"
|
||||
)
|
||||
else:
|
||||
storage_path = (
|
||||
f"{args.output_dir}/{args.prefix}_fbank_{partition}_{idx if split > 1 else ''}"
|
||||
)
|
||||
|
||||
if args.prefix.lower() in ["ljspeech", "aishell", "baker", "wenetspeech4tts"]:
|
||||
part = part.resample(24000)
|
||||
|
||||
with torch.no_grad():
|
||||
if (
|
||||
torch.cuda.is_available()
|
||||
and args.audio_extractor == "Encodec"
|
||||
):
|
||||
part = part.compute_and_store_features_batch(
|
||||
extractor=audio_extractor,
|
||||
storage_path=storage_path,
|
||||
num_workers=num_jobs,
|
||||
batch_duration=args.batch_duration,
|
||||
collate=False,
|
||||
overwrite=True,
|
||||
storage_type=NumpyHdf5Writer,
|
||||
)
|
||||
else:
|
||||
part = part.compute_and_store_features(
|
||||
extractor=audio_extractor,
|
||||
storage_path=storage_path,
|
||||
num_jobs=num_jobs if ex is None else 64,
|
||||
executor=ex,
|
||||
storage_type=NumpyHdf5Writer,
|
||||
)
|
||||
|
||||
# TextTokenizer
|
||||
if args.text_extractor:
|
||||
for c in tqdm(part):
|
||||
if args.prefix == "baker" and args.text_extractor == "labeled_pinyin":
|
||||
phonemes = c.supervisions[0].custom["tokens"]["text"]
|
||||
unique_symbols.update(phonemes)
|
||||
else:
|
||||
if args.prefix == "ljspeech":
|
||||
text = c.supervisions[0].custom["normalized_text"]
|
||||
text = text.replace("“", '"').replace("”", '"')
|
||||
phonemes = tokenize_text(text_tokenizer, text=text)
|
||||
elif args.prefix in ["aishell", "aishell2", "wenetspeech4tts", "libritts"]:
|
||||
phonemes = tokenize_text(
|
||||
text_tokenizer, text=c.supervisions[0].text
|
||||
)
|
||||
if c.supervisions[0].custom is None:
|
||||
c.supervisions[0].custom = {}
|
||||
else:
|
||||
raise NotImplementedError(f"{args.prefix}")
|
||||
c.supervisions[0].custom["tokens"] = {"text": phonemes}
|
||||
unique_symbols.update(phonemes)
|
||||
|
||||
# Save each part with an index if split > 1
|
||||
cuts_filename = f"{prefix}cuts_{partition}.{idx if split > 1 else ''}.{args.suffix}"
|
||||
part.to_file(f"{args.output_dir}/{cuts_filename}")
|
||||
logging.info(f"Saved {cuts_filename}")
|
||||
|
||||
if args.text_extractor:
|
||||
unique_phonemes = SymbolTable()
|
||||
for s in sorted(list(unique_symbols)):
|
||||
unique_phonemes.add(s)
|
||||
logging.info(f"{len(unique_symbols)} unique phonemes: {unique_symbols}")
|
||||
|
||||
unique_phonemes_file = f"{args.output_dir}/unique_text_tokens.k2symbols"
|
||||
unique_phonemes.to_file(unique_phonemes_file)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
formatter = (
|
||||
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
)
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
main()
|
@ -0,0 +1 @@
|
||||
../../../wenetspeech4tts/TTS/local/compute_neural_codec_and_prepare_text_tokens.py
|
@ -32,7 +32,7 @@ if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then
|
||||
cd vits/monotonic_align
|
||||
python setup.py build_ext --inplace
|
||||
cd ../../
|
||||
else
|
||||
log "monotonic_align lib already built"
|
||||
fi
|
||||
fi
|
||||
@ -75,11 +75,11 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
|
||||
log "Stage 2: Compute Spectrogram for LibriTTS"
|
||||
mkdir -p data/spectrogram
|
||||
if [ ! -e data/spectrogram/.libritts.done ]; then
|
||||
./local/compute_spectrogram_libritts.py --sampling-rate $sampling_rate
|
||||
touch data/spectrogram/.libritts.done
|
||||
fi
|
||||
|
||||
# Here we shuffle and combine the train-clean-100, train-clean-360 and
|
||||
# train-other-500 together to form the training set.
|
||||
if [ ! -f data/spectrogram/libritts_cuts_train-all-shuf.jsonl.gz ]; then
|
||||
cat <(gunzip -c data/spectrogram/libritts_cuts_train-clean-100.jsonl.gz) \
|
||||
@ -88,7 +88,7 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
|
||||
shuf | gzip -c > data/spectrogram/libritts_cuts_train-all-shuf.jsonl.gz
|
||||
fi
|
||||
|
||||
# Here we shuffle and combine the train-clean-100, train-clean-360
|
||||
# together to form the training set.
|
||||
if [ ! -f data/spectrogram/libritts_cuts_train-clean-460.jsonl.gz ]; then
|
||||
cat <(gunzip -c data/spectrogram/libritts_cuts_train-clean-100.jsonl.gz) \
|
||||
@ -108,10 +108,10 @@ if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
|
||||
log "Stage 3: Prepare phoneme tokens for LibriTTS"
|
||||
# We assume you have installed piper_phonemize and espnet_tts_frontend.
|
||||
# If not, please install them with:
|
||||
# - piper_phonemize:
|
||||
# refer to https://github.com/rhasspy/piper-phonemize,
|
||||
# could install the pre-built wheels from https://github.com/csukuangfj/piper-phonemize/releases/tag/2023.12.5
|
||||
# - espnet_tts_frontend:
|
||||
# `pip install espnet_tts_frontend`, refer to https://github.com/espnet/espnet_tts_frontend/
|
||||
if [ ! -e data/spectrogram/.libritts_with_token.done ]; then
|
||||
./local/prepare_tokens_libritts.py
|
||||
@ -123,12 +123,39 @@ if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
|
||||
log "Stage 4: Generate token file"
|
||||
# We assume you have installed piper_phonemize and espnet_tts_frontend.
|
||||
# If not, please install them with:
|
||||
# - piper_phonemize:
|
||||
# refer to https://github.com/rhasspy/piper-phonemize,
|
||||
# could install the pre-built wheels from https://github.com/csukuangfj/piper-phonemize/releases/tag/2023.12.5
|
||||
# - espnet_tts_frontend:
|
||||
# `pip install espnet_tts_frontend`, refer to https://github.com/espnet/espnet_tts_frontend/
|
||||
if [ ! -e data/tokens.txt ]; then
|
||||
./local/prepare_token_file.py --tokens data/tokens.txt
|
||||
fi
|
||||
fi
|
||||
|
||||
audio_feats_dir=data/tokenized
|
||||
dataset_parts="--dataset-parts all" # debug "-p dev-clean -p test-clean"
|
||||
if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
|
||||
log "Stage 5: Tokenize/Fbank LibriTTS for valle"
|
||||
mkdir -p ${audio_feats_dir}
|
||||
if [ ! -e ${audio_feats_dir}/.libritts.tokenize.done ]; then
|
||||
python3 ./local/compute_neural_codec_and_prepare_text_tokens.py --dataset-parts "${dataset_parts}" \
|
||||
--audio-extractor "Encodec" \
|
||||
--batch-duration 400 \
|
||||
--src-dir "data/manifests" \
|
||||
--output-dir "${audio_feats_dir}"
|
||||
fi
|
||||
touch ${audio_feats_dir}/.libritts.tokenize.done
|
||||
|
||||
lhotse combine \
|
||||
${audio_feats_dir}/libritts_cuts_train-clean-100.jsonl.gz \
|
||||
${audio_feats_dir}/libritts_cuts_train-clean-360.jsonl.gz \
|
||||
${audio_feats_dir}/libritts_cuts_train-other-500.jsonl.gz \
|
||||
${audio_feats_dir}/cuts_train.jsonl.gz
|
||||
lhotse copy \
|
||||
${audio_feats_dir}/libritts_cuts_dev-clean.jsonl.gz \
|
||||
${audio_feats_dir}/cuts_dev.jsonl.gz
|
||||
lhotse copy \
|
||||
${audio_feats_dir}/libritts_cuts_test-clean.jsonl.gz \
|
||||
${audio_feats_dir}/cuts_test.jsonl.gz
|
||||
fi
|
||||
|
1
egs/libritts/TTS/valle
Symbolic link
@ -0,0 +1 @@
|
||||
../../wenetspeech4tts/TTS/valle/
|
51
egs/wenetspeech4tts/TTS/README.md
Normal file
@ -0,0 +1,51 @@
|
||||
# Introduction
|
||||
|
||||
LibriTTS is a multi-speaker English corpus of approximately 585 hours of read English speech at 24kHz sampling rate, prepared by Heiga Zen with the assistance of Google Speech and Google Brain team members.
|
||||
The LibriTTS corpus is designed for TTS research. It is derived from the original materials (mp3 audio files from LibriVox and text files from Project Gutenberg) of the LibriSpeech corpus.
|
||||
The main differences from the LibriSpeech corpus are listed below:
|
||||
1. The audio files are at 24kHz sampling rate.
|
||||
2. The speech is split at sentence breaks.
|
||||
3. Both original and normalized texts are included.
|
||||
4. Contextual information (e.g., neighbouring sentences) can be extracted.
|
||||
5. Utterances with significant background noise are excluded.
|
||||
For more information, refer to the paper "LibriTTS: A Corpus Derived from LibriSpeech for Text-to-Speech", Heiga Zen, Viet Dang, Rob Clark, Yu Zhang, Ron J. Weiss, Ye Jia, Zhifeng Chen, and Yonghui Wu, arXiv, 2019. If you use the LibriTTS corpus in your work, please cite this paper where it was introduced.
|
||||
|
||||
> [!CAUTION]
|
||||
> The next-gen Kaldi framework provides tools and models for generating high-quality synthetic speech (Text-to-Speech, TTS).
|
||||
> While these recipes have the potential to advance various fields such as accessibility, language education, and AI-driven solutions, they also carry certain ethical and legal responsibilities.
|
||||
>
|
||||
> By using this framework, you agree to the following:
|
||||
> 1. Legal and Ethical Use: You shall not use this framework, or any models derived from it, for any unlawful or unethical purposes. This includes, but is not limited to: Creating voice clones without the explicit, informed consent of the individual whose voice is being cloned. Engaging in any form of identity theft, impersonation, or fraud using cloned voices. Violating any local, national, or international laws regarding privacy, intellectual property, or personal data.
|
||||
>
|
||||
> 2. Responsibility of Use: The users of this framework are solely responsible for ensuring that their use of voice cloning technologies complies with all applicable laws and ethical guidelines. We explicitly disclaim any liability for misuse of the technology.
|
||||
>
|
||||
> 3. Attribution and Use of Open-Source Components: This project is provided under the Apache 2.0 license. Users must adhere to the terms of this license and provide appropriate attribution when required.
|
||||
>
|
||||
> 4. No Warranty: This framework is provided “as-is,” without warranty of any kind, either express or implied. We do not guarantee that the use of this software will comply with legal requirements or that it will not infringe the rights of third parties.
|
||||
|
||||
|
||||
# VITS
|
||||
|
||||
This recipe provides a VITS model trained on the LibriTTS dataset.
|
||||
|
||||
A pretrained model can be found [here](https://huggingface.co/zrjin/icefall-tts-libritts-vits-2024-10-30).
|
||||
|
||||
The training command is given below:
|
||||
```
|
||||
export CUDA_VISIBLE_DEVICES=0,1,2,3
|
||||
./vits/train.py \
|
||||
--world-size 4 \
|
||||
--num-epochs 400 \
|
||||
--start-epoch 1 \
|
||||
--use-fp16 1 \
|
||||
--exp-dir vits/exp \
|
||||
--max-duration 500
|
||||
```
|
||||
|
||||
To run inference, use:
|
||||
```
|
||||
./vits/infer.py \
|
||||
--exp-dir vits/exp \
|
||||
--epoch 400 \
|
||||
--tokens data/tokens.txt
|
||||
```
|
615
egs/wenetspeech4tts/TTS/local/compute_neural_codec_and_prepare_text_tokens.py
Executable file
@ -0,0 +1,615 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2023 (authors: Feiteng Li)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Phonemize Text and EnCodec Audio.
|
||||
|
||||
Usage example:
|
||||
python3 local/compute_neural_codec_and_prepare_text_tokens.py \
--src-dir ./data/manifests --output-dir ./data/tokenized
|
||||
|
||||
"""
|
||||
import argparse
|
||||
import logging
|
||||
import os
|
||||
from dataclasses import asdict, dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.multiprocessing
|
||||
from encodec import EncodecModel
|
||||
from encodec.utils import convert_audio
|
||||
from lhotse import CutSet, NumpyHdf5Writer
|
||||
from lhotse.features import FeatureExtractor
|
||||
from lhotse.recipes.utils import read_manifests_if_cached
|
||||
from lhotse.utils import Seconds, compute_num_frames
|
||||
from phonemizer.backend import EspeakBackend
|
||||
from phonemizer.backend.espeak.language_switch import LanguageSwitch
|
||||
from phonemizer.backend.espeak.words_mismatch import WordMismatch
|
||||
from phonemizer.punctuation import Punctuation
|
||||
from phonemizer.separator import Separator
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
from icefall.utils import get_executor
|
||||
|
||||
try:
|
||||
from pypinyin import Style, pinyin
|
||||
from pypinyin.style._utils import get_finals, get_initials
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
import re
|
||||
from typing import Pattern
|
||||
|
||||
import numpy as np
|
||||
from k2 import SymbolTable
|
||||
|
||||
# from valle.data import (
|
||||
# AudioTokenConfig,
|
||||
# AudioTokenExtractor,
|
||||
# TextTokenizer,
|
||||
# tokenize_text,
|
||||
# )
|
||||
# from valle.data.fbank import get_fbank_extractor
|
||||
# from valle.utils import SymbolTable
|
||||
|
||||
os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"
|
||||
|
||||
|
||||
# Torch's multithreaded behavior needs to be disabled or
|
||||
# it wastes a lot of CPU and slow things down.
|
||||
# Do this outside of main() in case it needs to take effect
|
||||
# even when we are not invoking the main (e.g. when spawning subprocesses).
|
||||
torch.set_num_threads(1)
|
||||
torch.set_num_interop_threads(1)
|
||||
torch.multiprocessing.set_sharing_strategy("file_system")
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument(
|
||||
"--src-dir",
|
||||
type=Path,
|
||||
default=Path("data/manifests"),
|
||||
help="Path to the manifest files",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output-dir",
|
||||
type=Path,
|
||||
default=Path("data/tokenized"),
|
||||
help="Path to the tokenized files",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--text-extractor",
|
||||
type=str,
|
||||
default="espeak",
|
||||
help="espeak or pypinyin or pypinyin_initials_finals",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--audio-extractor",
|
||||
type=str,
|
||||
default="Encodec",
|
||||
help="Encodec or Fbank",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dataset-parts",
|
||||
type=str,
|
||||
default="dev-clean test-clean",
|
||||
help="Space separated dataset parts",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--prefix",
|
||||
type=str,
|
||||
default="libritts",
|
||||
help="prefix of the manifest file",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--suffix",
|
||||
type=str,
|
||||
default="jsonl.gz",
|
||||
help="suffix of the manifest file",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--batch-duration",
|
||||
type=float,
|
||||
default=400.0,
|
||||
help="The maximum number of audio seconds in a batch."
|
||||
"Determines batch size dynamically.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--split",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Split the cut_set into multiple parts",
|
||||
)
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
class PypinyinBackend:
|
||||
"""PypinyinBackend for Chinese. Most codes is referenced from espnet.
|
||||
There are two types pinyin or initials_finals, one is
|
||||
just like "ni1 hao3", the other is like "n i1 h ao3".
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
backend="initials_finals",
|
||||
punctuation_marks: Union[str, Pattern] = Punctuation.default_marks(),
|
||||
) -> None:
|
||||
self.backend = backend
|
||||
self.punctuation_marks = punctuation_marks
|
||||
|
||||
def phonemize(
|
||||
self, text: List[str], separator: Separator, strip=True, njobs=1
|
||||
) -> List[str]:
|
||||
assert isinstance(text, List)
|
||||
phonemized = []
|
||||
for _text in text:
|
||||
_text = re.sub(" +", " ", _text.strip())
|
||||
_text = _text.replace(" ", separator.word)
|
||||
phones = []
|
||||
if self.backend == "pypinyin":
|
||||
for n, py in enumerate(
|
||||
pinyin(_text, style=Style.TONE3, neutral_tone_with_five=True)
|
||||
):
|
||||
if all([c in self.punctuation_marks for c in py[0]]):
|
||||
if len(phones):
|
||||
assert phones[-1] == separator.syllable
|
||||
phones.pop(-1)
|
||||
|
||||
phones.extend(list(py[0]))
|
||||
else:
|
||||
phones.extend([py[0], separator.syllable])
|
||||
elif self.backend == "pypinyin_initials_finals":
|
||||
for n, py in enumerate(
|
||||
pinyin(_text, style=Style.TONE3, neutral_tone_with_five=True)
|
||||
):
|
||||
if all([c in self.punctuation_marks for c in py[0]]):
|
||||
if len(phones):
|
||||
assert phones[-1] == separator.syllable
|
||||
phones.pop(-1)
|
||||
phones.extend(list(py[0]))
|
||||
else:
|
||||
if py[0][-1].isalnum():
|
||||
initial = get_initials(py[0], strict=False)
|
||||
if py[0][-1].isdigit():
|
||||
final = get_finals(py[0][:-1], strict=False) + py[0][-1]
|
||||
else:
|
||||
final = get_finals(py[0], strict=False)
|
||||
phones.extend(
|
||||
[
|
||||
initial,
|
||||
separator.phone,
|
||||
final,
|
||||
separator.syllable,
|
||||
]
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unexpected pinyin token: {py[0]}")
|
||||
else:
|
||||
raise NotImplementedError
|
||||
phonemized.append(
|
||||
"".join(phones).rstrip(f"{separator.word}{separator.syllable}")
|
||||
)
|
||||
return phonemized
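# Illustrative usage sketch (not exercised by this script directly; assumes pypinyin
# is installed):
#   sep = Separator(word="_", syllable="-", phone="|")
#   PypinyinBackend(backend="pypinyin").phonemize(["你好"], sep)
#       -> ["ni3-hao3"]        # whole syllables with TONE3 tone digits
#   PypinyinBackend(backend="pypinyin_initials_finals").phonemize(["你好"], sep)
#       -> ["n|i3-h|ao3"]      # initials and finals split by the phone separator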
|
||||
|
||||
|
||||
class TextTokenizer:
|
||||
"""Phonemize Text."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
language="en-us",
|
||||
backend="espeak",
|
||||
separator=Separator(word="_", syllable="-", phone="|"),
|
||||
preserve_punctuation=True,
|
||||
punctuation_marks: Union[str, Pattern] = Punctuation.default_marks(),
|
||||
with_stress: bool = False,
|
||||
tie: Union[bool, str] = False,
|
||||
language_switch: LanguageSwitch = "keep-flags",
|
||||
words_mismatch: WordMismatch = "ignore",
|
||||
) -> None:
|
||||
if backend == "espeak":
|
||||
phonemizer = EspeakBackend(
|
||||
language,
|
||||
punctuation_marks=punctuation_marks,
|
||||
preserve_punctuation=preserve_punctuation,
|
||||
with_stress=with_stress,
|
||||
tie=tie,
|
||||
language_switch=language_switch,
|
||||
words_mismatch=words_mismatch,
|
||||
)
|
||||
elif backend in ["pypinyin", "pypinyin_initials_finals"]:
|
||||
phonemizer = PypinyinBackend(
|
||||
backend=backend,
|
||||
punctuation_marks=punctuation_marks + separator.word,
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(f"{backend}")
|
||||
|
||||
self.backend = phonemizer
|
||||
self.separator = separator
|
||||
|
||||
def to_list(self, phonemized: str) -> List[str]:
|
||||
fields = []
|
||||
for word in phonemized.split(self.separator.word):
|
||||
# "ɐ m|iː|n?" ɹ|ɪ|z|ɜː|v; h|ɪ|z.
|
||||
pp = re.findall(r"\w+|[^\w\s]", word, re.UNICODE)
|
||||
fields.extend(
|
||||
[p for p in pp if p != self.separator.phone] + [self.separator.word]
|
||||
)
|
||||
assert len("".join(fields[:-1])) == len(phonemized) - phonemized.count(
|
||||
self.separator.phone
|
||||
)
|
||||
return fields[:-1]
|
||||
|
||||
def __call__(self, text, strip=True) -> List[List[str]]:
|
||||
if isinstance(text, str):
|
||||
text = [text]
|
||||
|
||||
phonemized = self.backend.phonemize(
|
||||
text, separator=self.separator, strip=strip, njobs=1
|
||||
)
|
||||
return [self.to_list(p) for p in phonemized]
|
||||
|
||||
|
||||
def tokenize_text(tokenizer: TextTokenizer, text: str) -> List[str]:
|
||||
phonemes = tokenizer([text.strip()])
|
||||
return phonemes[0] # k2symbols
|
||||
|
||||
|
||||
def remove_encodec_weight_norm(model):
|
||||
from encodec.modules import SConv1d
|
||||
from encodec.modules.seanet import SConvTranspose1d, SEANetResnetBlock
|
||||
from torch.nn.utils import remove_weight_norm
|
||||
|
||||
encoder = model.encoder.model
|
||||
for key in encoder._modules:
|
||||
if isinstance(encoder._modules[key], SEANetResnetBlock):
|
||||
remove_weight_norm(encoder._modules[key].shortcut.conv.conv)
|
||||
block_modules = encoder._modules[key].block._modules
|
||||
for skey in block_modules:
|
||||
if isinstance(block_modules[skey], SConv1d):
|
||||
remove_weight_norm(block_modules[skey].conv.conv)
|
||||
elif isinstance(encoder._modules[key], SConv1d):
|
||||
remove_weight_norm(encoder._modules[key].conv.conv)
|
||||
|
||||
decoder = model.decoder.model
|
||||
for key in decoder._modules:
|
||||
if isinstance(decoder._modules[key], SEANetResnetBlock):
|
||||
remove_weight_norm(decoder._modules[key].shortcut.conv.conv)
|
||||
block_modules = decoder._modules[key].block._modules
|
||||
for skey in block_modules:
|
||||
if isinstance(block_modules[skey], SConv1d):
|
||||
remove_weight_norm(block_modules[skey].conv.conv)
|
||||
elif isinstance(decoder._modules[key], SConvTranspose1d):
|
||||
remove_weight_norm(decoder._modules[key].convtr.convtr)
|
||||
elif isinstance(decoder._modules[key], SConv1d):
|
||||
remove_weight_norm(decoder._modules[key].conv.conv)
|
||||
|
||||
|
||||
class AudioTokenizer:
|
||||
"""EnCodec audio."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
device: Any = None,
|
||||
) -> None:
|
||||
# Instantiate a pretrained EnCodec model
|
||||
model = EncodecModel.encodec_model_24khz()
|
||||
model.set_target_bandwidth(6.0)
|
||||
remove_encodec_weight_norm(model)
|
||||
|
||||
if not device:
|
||||
device = torch.device("cpu")
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda:0")
|
||||
|
||||
self._device = device
|
||||
|
||||
self.codec = model.to(device)
|
||||
self.sample_rate = model.sample_rate
|
||||
self.channels = model.channels
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
return self._device
|
||||
|
||||
def encode(self, wav: torch.Tensor) -> torch.Tensor:
|
||||
return self.codec.encode(wav.to(self.device))
|
||||
|
||||
def decode(self, frames: torch.Tensor) -> torch.Tensor:
|
||||
return self.codec.decode(frames)
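# Usage sketch (illustrative; `wav` is assumed to be a [1, 1, T] float tensor at 24 kHz,
# e.g. loaded with torchaudio and resampled via encodec.utils.convert_audio):
#   tokenizer = AudioTokenizer()
#   frames = tokenizer.encode(wav)      # list of (codes, scale) pairs from EnCodec
#   codes = frames[0][0]                # [B, n_q, T'] with n_q == 8 at 6.0 kbps
#   audio = tokenizer.decode(frames)    # [B, 1, T] reconstructed waveform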
|
||||
|
||||
|
||||
@dataclass
|
||||
class AudioTokenConfig:
|
||||
frame_shift: Seconds = 320.0 / 24000
|
||||
num_quantizers: int = 8
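# A 320-sample hop at 24 kHz gives 75 codec frames per second; 8 quantizers matches
# the 6.0 kbps target bandwidth selected in AudioTokenizer above.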
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return asdict(self)
|
||||
|
||||
@staticmethod
|
||||
def from_dict(data: Dict[str, Any]) -> "AudioTokenConfig":
|
||||
return AudioTokenConfig(**data)
|
||||
|
||||
|
||||
class AudioTokenExtractor(FeatureExtractor):
|
||||
name = "encodec"
|
||||
config_type = AudioTokenConfig
|
||||
|
||||
def __init__(self, config: Optional[Any] = None):
|
||||
super(AudioTokenExtractor, self).__init__(config)
|
||||
self.tokenizer = AudioTokenizer()
|
||||
|
||||
def extract(
|
||||
self, samples: Union[np.ndarray, torch.Tensor], sampling_rate: int
|
||||
) -> np.ndarray:
|
||||
if not isinstance(samples, torch.Tensor):
|
||||
samples = torch.from_numpy(samples)
|
||||
if sampling_rate != self.tokenizer.sample_rate:
|
||||
samples = convert_audio(
|
||||
samples,
|
||||
sampling_rate,
|
||||
self.tokenizer.sample_rate,
|
||||
self.tokenizer.channels,
|
||||
)
|
||||
if len(samples.shape) == 2:
|
||||
samples = samples.unsqueeze(0)
|
||||
else:
|
||||
raise ValueError()
|
||||
|
||||
device = self.tokenizer.device
|
||||
encoded_frames = self.tokenizer.encode(samples.detach().to(device))
|
||||
codes = encoded_frames[0][0] # [B, n_q, T]
|
||||
if True:
|
||||
duration = round(samples.shape[-1] / sampling_rate, ndigits=12)
|
||||
expected_num_frames = compute_num_frames(
|
||||
duration=duration,
|
||||
frame_shift=self.frame_shift,
|
||||
sampling_rate=sampling_rate,
|
||||
)
|
||||
assert abs(codes.shape[-1] - expected_num_frames) <= 1
|
||||
codes = codes[..., :expected_num_frames]
|
||||
return codes.cpu().squeeze(0).permute(1, 0).numpy()
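# The returned array is laid out like a feature matrix: [num_frames, num_quantizers],
# i.e. one row of 8 codebook indices per 75 Hz frame, which is what lhotse stores.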
|
||||
|
||||
@property
|
||||
def frame_shift(self) -> Seconds:
|
||||
return self.config.frame_shift
|
||||
|
||||
def feature_dim(self, sampling_rate: int) -> int:
|
||||
return self.config.num_quantizers
|
||||
|
||||
def pad_tensor_list(self, tensor_list, device, padding_value=0):
|
||||
lengths = [tensor.shape[0] for tensor in tensor_list]
|
||||
tensor_list = [torch.Tensor(t).to(device) for t in tensor_list]
|
||||
padded_tensor = torch.nn.utils.rnn.pad_sequence(
|
||||
tensor_list, batch_first=True, padding_value=padding_value
|
||||
)
|
||||
return padded_tensor, lengths
|
||||
|
||||
def extract_batch(self, samples, sampling_rate, lengths) -> np.ndarray:
|
||||
samples = [wav.squeeze() for wav in samples]
|
||||
device = self.tokenizer.device
|
||||
samples, lengths = self.pad_tensor_list(samples, device)
|
||||
samples = samples.unsqueeze(1)
|
||||
|
||||
if not isinstance(samples, torch.Tensor):
|
||||
samples = torch.from_numpy(samples)
|
||||
if len(samples.shape) != 3:
|
||||
raise ValueError()
|
||||
if sampling_rate != self.tokenizer.sample_rate:
|
||||
samples = [
|
||||
convert_audio(
|
||||
wav,
|
||||
sampling_rate,
|
||||
self.tokenizer.sample_rate,
|
||||
self.tokenizer.channels,
|
||||
)
|
||||
for wav in samples
|
||||
]
|
||||
samples = torch.stack(samples, 0) # convert samples from list to tensor
|
||||
# Extract discrete codes from EnCodec
|
||||
with torch.no_grad():
|
||||
encoded_frames = self.tokenizer.encode(samples.detach().to(device))
|
||||
encoded_frames = encoded_frames[0][0] # [B, n_q, T]
|
||||
batch_codes = []
|
||||
for b, length in enumerate(lengths):
|
||||
codes = encoded_frames[b]
|
||||
duration = round(length / sampling_rate, ndigits=12)
|
||||
expected_num_frames = compute_num_frames(
|
||||
duration=duration,
|
||||
frame_shift=self.frame_shift,
|
||||
sampling_rate=sampling_rate,
|
||||
)
|
||||
batch_codes.append(codes[..., :expected_num_frames])
|
||||
return [codes.cpu().permute(1, 0).numpy() for codes in batch_codes]
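# Unlike extract(), this returns a Python list: each element holds one utterance's
# codes, trimmed back to its un-padded length and shaped [num_frames, num_quantizers].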
|
||||
|
||||
|
||||
def main():
|
||||
args = get_args()
|
||||
|
||||
dataset_parts = args.dataset_parts.replace("--dataset-parts", "").strip()
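# prepare.sh forwards strings such as "--dataset-parts all" or "-p Basic -p Premium",
# so drop the flag token here (and the "-p" markers below) to keep only the part names.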
|
||||
if dataset_parts == "all": # LibriTTS
|
||||
dataset_parts = [
|
||||
"dev-clean",
|
||||
"dev-other",
|
||||
"test-clean",
|
||||
"test-other",
|
||||
"train-clean-100",
|
||||
"train-clean-360",
|
||||
"train-other-500",
|
||||
]
|
||||
else:
|
||||
dataset_parts = dataset_parts.replace("-p", "").strip().split(" ")
|
||||
|
||||
assert len(dataset_parts) >= 1
|
||||
|
||||
manifests = read_manifests_if_cached(
|
||||
dataset_parts=dataset_parts,
|
||||
output_dir=args.src_dir,
|
||||
prefix=args.prefix,
|
||||
suffix=args.suffix,
|
||||
types=["recordings", "supervisions", "cuts"],
|
||||
)
|
||||
|
||||
text_tokenizer = None
|
||||
if args.text_extractor:
|
||||
text_tokenizer = TextTokenizer(backend=args.text_extractor)
|
||||
|
||||
audio_extractor = None
|
||||
if args.audio_extractor:
|
||||
if args.audio_extractor == "Encodec":
|
||||
audio_extractor = AudioTokenExtractor(AudioTokenConfig())
|
||||
else:
|
||||
raise NotImplementedError(f"{args.audio_extractor}")
|
||||
|
||||
Path(args.output_dir).mkdir(parents=True, exist_ok=True)
|
||||
unique_symbols = set()
|
||||
num_jobs = min(32, os.cpu_count())
|
||||
logging.info(f"dataset_parts: {dataset_parts} manifests {len(manifests)}")
|
||||
|
||||
prefix = args.prefix
|
||||
if prefix and not prefix.endswith("_"):
|
||||
prefix = f"{prefix}_"
|
||||
with get_executor() as ex:
|
||||
for partition, m in manifests.items():
|
||||
logging.info(
|
||||
f"Processing partition: {partition} CUDA: {torch.cuda.is_available()}"
|
||||
)
|
||||
try:
|
||||
cut_set = CutSet.from_manifests(
|
||||
recordings=m["recordings"],
|
||||
supervisions=m["supervisions"],
|
||||
)
|
||||
except Exception:
|
||||
cut_set = m["cuts"]
|
||||
|
||||
# Split cut_set if split > 1
|
||||
split = 1
|
||||
if args.split > 1:
|
||||
cut_sets = cut_set.split(args.split)
|
||||
split = args.split
|
||||
else:
|
||||
cut_sets = [cut_set]
|
||||
|
||||
for idx, part in enumerate(cut_sets):
|
||||
if args.audio_extractor:
|
||||
if args.audio_extractor == "Encodec":
|
||||
storage_path = f"{args.output_dir}/{args.prefix}_encodec_{partition}_{idx if split > 1 else ''}"
|
||||
else:
|
||||
storage_path = f"{args.output_dir}/{args.prefix}_fbank_{partition}_{idx if split > 1 else ''}"
|
||||
|
||||
if args.prefix.lower() in [
|
||||
"ljspeech",
|
||||
"aishell",
|
||||
"baker",
|
||||
"wenetspeech4tts",
|
||||
]:
|
||||
part = part.resample(24000)
|
||||
assert args.prefix.lower() in [
|
||||
"ljspeech",
|
||||
"aishell",
|
||||
"baker",
|
||||
"wenetspeech4tts",
|
||||
"libritts",
|
||||
"libritts-r",
|
||||
]
|
||||
with torch.no_grad():
|
||||
if (
|
||||
torch.cuda.is_available()
|
||||
and args.audio_extractor == "Encodec"
|
||||
):
|
||||
part = part.compute_and_store_features_batch(
|
||||
extractor=audio_extractor,
|
||||
storage_path=storage_path,
|
||||
num_workers=num_jobs,
|
||||
batch_duration=args.batch_duration,
|
||||
collate=False,
|
||||
overwrite=True,
|
||||
storage_type=NumpyHdf5Writer,
|
||||
)
|
||||
else:
|
||||
part = part.compute_and_store_features(
|
||||
extractor=audio_extractor,
|
||||
storage_path=storage_path,
|
||||
num_jobs=num_jobs if ex is None else 64,
|
||||
executor=ex,
|
||||
storage_type=NumpyHdf5Writer,
|
||||
)
|
||||
|
||||
# TextTokenizer
|
||||
if args.text_extractor:
|
||||
for c in tqdm(part):
|
||||
if (
|
||||
args.prefix == "baker"
|
||||
and args.text_extractor == "labeled_pinyin"
|
||||
):
|
||||
phonemes = c.supervisions[0].custom["tokens"]["text"]
|
||||
unique_symbols.update(phonemes)
|
||||
else:
|
||||
if args.prefix == "ljspeech":
|
||||
text = c.supervisions[0].custom["normalized_text"]
|
||||
text = text.replace("“", '"').replace("”", '"')
|
||||
phonemes = tokenize_text(text_tokenizer, text=text)
|
||||
elif args.prefix in [
|
||||
"aishell",
|
||||
"aishell2",
|
||||
"wenetspeech4tts",
|
||||
"libritts",
|
||||
"libritts-r",
|
||||
]:
|
||||
phonemes = tokenize_text(
|
||||
text_tokenizer, text=c.supervisions[0].text
|
||||
)
|
||||
if c.supervisions[0].custom is None:
|
||||
c.supervisions[0].custom = {}
|
||||
c.supervisions[0].normalized_text = c.supervisions[
|
||||
0
|
||||
].text
|
||||
else:
|
||||
raise NotImplementedError(f"{args.prefix}")
|
||||
c.supervisions[0].custom["tokens"] = {"text": phonemes}
|
||||
unique_symbols.update(phonemes)
|
||||
c.tokens = phonemes
|
||||
assert c.supervisions[
|
||||
0
|
||||
].normalized_text, "normalized_text is None"
|
||||
|
||||
# Save each part with an index if split > 1
|
||||
cuts_filename = (
|
||||
f"{prefix}cuts_{partition}.{idx if split > 1 else ''}.{args.suffix}"
|
||||
)
|
||||
part.to_file(f"{args.output_dir}/{cuts_filename}")
|
||||
logging.info(f"Saved {cuts_filename}")
|
||||
|
||||
if args.text_extractor:
|
||||
unique_phonemes = SymbolTable()
|
||||
for s in sorted(list(unique_symbols)):
|
||||
unique_phonemes.add(s)
|
||||
logging.info(f"{len(unique_symbols)} unique phonemes: {unique_symbols}")
|
||||
|
||||
unique_phonemes_file = f"{args.output_dir}/unique_text_tokens.k2symbols"
|
||||
unique_phonemes.to_file(unique_phonemes_file)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
main()
|
53
egs/wenetspeech4tts/TTS/local/display_manifest_statistics.py
Executable file
@ -0,0 +1,53 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||
# Copyright 2023 (authors: Feiteng Li)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
This file displays duration statistics of utterances in the manifests.
|
||||
You can use the displayed value to choose minimum/maximum duration
|
||||
to remove short and long utterances during the training.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
from lhotse import load_manifest_lazy
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--manifest-dir",
|
||||
type=Path,
|
||||
default=Path("data/tokenized"),
|
||||
help="Path to the tokenized manifests.",
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def main():
|
||||
args = get_args()
|
||||
manifest_dir = args.manifest_dir or Path("data/tokenized")
|
||||
for part in ["train", "dev", "test"]:
|
||||
print(f"## {part}")
|
||||
cuts = load_manifest_lazy(manifest_dir / f"cuts_{part}.jsonl.gz")
|
||||
cuts.describe()
|
||||
print("\n")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
101
egs/wenetspeech4tts/TTS/prepare.sh
Executable file
@ -0,0 +1,101 @@
|
||||
#!/usr/bin/env bash
|
||||
|
||||
set -eou pipefail
|
||||
|
||||
# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674
|
||||
export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
|
||||
|
||||
j=16
|
||||
stage=2
|
||||
stop_stage=2
|
||||
|
||||
dl_dir=$PWD/download
|
||||
|
||||
dataset_parts="-p Basic" # -p Premium for Premium dataset only
|
||||
|
||||
text_extractor="pypinyin_initials_finals" # default is espeak for English
|
||||
audio_extractor="Encodec" # or Fbank
|
||||
audio_feats_dir=data/tokenized
|
||||
|
||||
. shared/parse_options.sh || exit 1
|
||||
|
||||
|
||||
# All files generated by this script are saved in "data".
|
||||
# You can safely remove "data" and rerun this script to regenerate it.
|
||||
mkdir -p data
|
||||
log() {
|
||||
# This function is from espnet
|
||||
local fname=${BASH_SOURCE[1]##*/}
|
||||
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
|
||||
}
|
||||
|
||||
if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
|
||||
log "dl_dir: $dl_dir"
|
||||
log "Stage 0: Download data"
|
||||
huggingface-cli login
|
||||
huggingface-cli download --repo-type dataset --local-dir $dl_dir Wenetspeech4TTS/WenetSpeech4TTS
|
||||
|
||||
# Extract the downloaded data:
|
||||
for folder in Standard Premium Basic; do
|
||||
for file in "$dl_dir/$folder"/*.tar.gz; do
|
||||
tar -xzvf "$file" -C "$dl_dir/$folder"
|
||||
done
|
||||
done
|
||||
fi
|
||||
|
||||
if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
|
||||
log "Stage 1: Prepare wenetspeech4tts manifest"
|
||||
# We assume that you have downloaded the wenetspeech4tts corpus
|
||||
# to $dl_dir/wenetspeech4tts
|
||||
mkdir -p data/manifests
|
||||
if [ ! -e data/manifests/.wenetspeech4tts.done ]; then
|
||||
lhotse prepare wenetspeech4tts $dl_dir data/manifests --dataset-parts "${dataset_parts}"
|
||||
touch data/manifests/.wenetspeech4tts.done
|
||||
fi
|
||||
fi
|
||||
|
||||
|
||||
if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
|
||||
log "Stage 2: Tokenize/Fbank wenetspeech4tts"
|
||||
mkdir -p ${audio_feats_dir}
|
||||
if [ ! -e ${audio_feats_dir}/.wenetspeech4tts.tokenize.done ]; then
|
||||
python3 ./local/compute_neural_codec_and_prepare_text_tokens.py --dataset-parts "${dataset_parts}" \
|
||||
--text-extractor ${text_extractor} \
|
||||
--audio-extractor ${audio_extractor} \
|
||||
--batch-duration 2500 \
|
||||
--prefix "wenetspeech4tts" \
|
||||
--src-dir "data/manifests" \
|
||||
--split 100 \
|
||||
--output-dir "${audio_feats_dir}/${prefix}_baisc_split_100"
|
||||
fi
|
||||
touch ${audio_feats_dir}/.wenetspeech4tts.tokenize.done
|
||||
fi
|
||||
|
||||
if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
|
||||
log "Stage 13: Combine features for basic"
|
||||
if [ ! -f ${audio_feats_dir}/wenetspeech4tts_cuts_Baisc.jsonl.gz ]; then
|
||||
pieces=$(find ${audio_feats_dir}/wenetspeech4tts_baisc_split_100 -name "*.jsonl.gz")
|
||||
lhotse combine $pieces ${audio_feats_dir}/wenetspeech4tts_cuts_Baisc.jsonl.gz
|
||||
fi
|
||||
fi
|
||||
|
||||
if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
|
||||
log "Stage 3: Prepare wenetspeech4tts train/dev/test"
|
||||
if [ ! -e ${audio_feats_dir}/.wenetspeech4tts.train.done ]; then
|
||||
|
||||
lhotse subset --first 400 \
|
||||
${audio_feats_dir}/wenetspeech4tts_cuts_Baisc.jsonl.gz \
|
||||
${audio_feats_dir}/cuts_dev.jsonl.gz
|
||||
|
||||
lhotse subset --last 400 \
|
||||
${audio_feats_dir}/wenetspeech4tts_cuts_Baisc.jsonl.gz \
|
||||
${audio_feats_dir}/cuts_test.jsonl.gz
|
||||
|
||||
lhotse copy \
|
||||
${audio_feats_dir}/wenetspeech4tts_cuts_Baisc.jsonl.gz \
|
||||
${audio_feats_dir}/cuts_train.jsonl.gz
|
||||
|
||||
touch ${audio_feats_dir}/.wenetspeech4tts.train.done
|
||||
fi
|
||||
python3 ./local/display_manifest_statistics.py --manifest-dir ${audio_feats_dir}
|
||||
fi
|
1
egs/wenetspeech4tts/TTS/shared
Symbolic link
@ -0,0 +1 @@
|
||||
../../../icefall/shared/
|
@ -0,0 +1 @@
|
||||
../local/compute_neural_codec_and_prepare_text_tokens.py
|
@ -40,16 +40,17 @@ os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"
|
||||
|
||||
import torch
|
||||
import torchaudio
|
||||
from icefall.utils import AttributeDict, str2bool
|
||||
|
||||
from valle.data import (
|
||||
from compute_neural_codec_and_prepare_text_tokens import (
|
||||
AudioTokenizer,
|
||||
TextTokenizer,
|
||||
tokenize_audio,
|
||||
tokenize_text,
|
||||
)
|
||||
from valle.data.collation import get_text_token_collater
|
||||
from valle.models import get_model
|
||||
from k2 import symbol_table
|
||||
from tokenizer import get_text_token_collater
|
||||
from valle import VALLE
|
||||
|
||||
from icefall.utils import AttributeDict, str2bool
|
||||
|
||||
|
||||
def get_args():
|
||||
@ -70,21 +71,12 @@ def get_args():
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--text",
|
||||
"--manifest",
|
||||
type=str,
|
||||
default="To get up and running quickly just follow the steps below.",
|
||||
help="Text to be synthesized.",
|
||||
default="",
|
||||
help="prompt text\t prompt audio\ttarget text\ttarget audio",
|
||||
)
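# Each line of --manifest is expected to hold four tab-separated fields
# (paths below are illustrative):
#   prompt text<TAB>prompts/p0.wav<TAB>target text<TAB>wavs/t0.wav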
|
||||
|
||||
# model
|
||||
# add_model_arguments(parser)
|
||||
# parser.add_argument(
|
||||
# "--text-tokens",
|
||||
# type=str,
|
||||
# default="data/tokenized/unique_text_tokens.k2symbols",
|
||||
# help="Path to the unique text tokens file.",
|
||||
# )
|
||||
|
||||
parser.add_argument(
|
||||
"--text-extractor",
|
||||
type=str,
|
||||
@ -143,8 +135,19 @@ def load_model(checkpoint, device):

    checkpoint = torch.load(checkpoint, map_location=device)

    args = AttributeDict(checkpoint)
    model = get_model(args)
    params = AttributeDict(checkpoint)
    model = VALLE(
        params.decoder_dim,
        params.nhead,
        params.num_decoder_layers,
        norm_first=params.norm_first,
        add_prenet=params.add_prenet,
        prefix_mode=params.prefix_mode,
        share_embedding=params.share_embedding,
        nar_scale_factor=params.scale_factor,
        prepend_bos=params.prepend_bos,
        num_quantizers=params.num_quantizers,
    )

    missing_keys, unexpected_keys = model.load_state_dict(
        checkpoint["model"], strict=True
@ -153,9 +156,7 @@ def load_model(checkpoint, device):
    model.to(device)
    model.eval()

    text_tokens = args.text_tokens

    return model, text_tokens
    return model, params.text_tokens


@torch.no_grad()
@ -181,9 +182,7 @@ def main():
            encoded_frames = tokenize_audio(audio_tokenizer, audio_file)
            if False:
                samples = audio_tokenizer.decode(encoded_frames)
                torchaudio.save(
                    f"{args.output_dir}/p{n}.wav", samples[0], 24000
                )
                torchaudio.save(f"{args.output_dir}/p{n}.wav", samples[0], 24000)

            audio_prompts.append(encoded_frames[0][0])
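
Rebuilding VALLE from hyperparameters stored inside the checkpoint (as load_model now does) assumes training saved those fields next to the weights; a hedged sketch of that convention, with the field set mirroring what load_model reads back and the file name as a placeholder:

# Sketch only: persist hyperparameters alongside the weights so that
# load_model() can reconstruct VALLE from the checkpoint alone.
import torch

def save_inference_checkpoint(model, params, filename="epoch-N.pt"):  # placeholder name
    torch.save(
        {
            "model": model.state_dict(),
            "decoder_dim": params.decoder_dim,
            "nhead": params.nhead,
            "num_decoder_layers": params.num_decoder_layers,
            "norm_first": params.norm_first,
            "add_prenet": params.add_prenet,
            "prefix_mode": params.prefix_mode,
            "share_embedding": params.share_embedding,
            "scale_factor": params.scale_factor,
            "prepend_bos": params.prepend_bos,
            "num_quantizers": params.num_quantizers,
            "text_tokens": params.text_tokens,  # path to unique_text_tokens.k2symbols
        },
        filename,
    )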
@ -195,8 +194,8 @@ def main():
    # https://github.com/lifeiteng/lifeiteng.github.com/blob/main/valle/prepare.py
    with open(args.text) as f:
        for line in f:
            # fields = line.strip().split("\t")
            fields = line.strip().split(" ")
            fields = line.strip().split("\t")
            # fields = line.strip().split(" ")
            fields = [item for item in fields if item]
            assert len(fields) == 4
            prompt_text, prompt_audio, text, audio_path = fields
@ -209,11 +208,7 @@ def main():
                ]
            )
            _, enroll_x_lens = text_collater(
                [
                    tokenize_text(
                        text_tokenizer, text=f"{prompt_text}".strip()
                    )
                ]
                [tokenize_text(text_tokenizer, text=f"{prompt_text}".strip())]
            )

            audio_prompts = tokenize_audio(audio_tokenizer, prompt_audio)
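
Each line of the file passed via --manifest therefore carries four tab-separated fields: prompt text, prompt audio, target text, target audio. A small sketch that writes one such line (all paths and sentences are placeholders):

# Sketch only: build a one-line inference manifest.
with open("infer_manifest.tsv", "w") as f:
    f.write(
        "\t".join(
            [
                "prompt transcript here",  # prompt text
                "/path/to/prompt.wav",  # prompt audio
                "target text to synthesize",  # target text
                "/path/to/target.wav",  # target audio
            ]
        )
        + "\n"
    )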
@ -244,11 +239,7 @@ def main():
    for n, text in enumerate(args.text.split("|")):
        logging.info(f"synthesize text: {text}")
        text_tokens, text_tokens_lens = text_collater(
            [
                tokenize_text(
                    text_tokenizer, text=f"{text_prompts} {text}".strip()
                )
            ]
            [tokenize_text(text_tokenizer, text=f"{text_prompts} {text}".strip())]
        )

        # synthesis
@ -263,11 +254,7 @@ def main():
            enroll_x_lens = None
            if text_prompts:
                _, enroll_x_lens = text_collater(
                    [
                        tokenize_text(
                            text_tokenizer, text=f"{text_prompts}".strip()
                        )
                    ]
                    [tokenize_text(text_tokenizer, text=f"{text_prompts}".strip())]
                )
            encoded_frames = model.inference(
                text_tokens.to(device),
@ -280,13 +267,9 @@ def main():
            )

        if audio_prompts != []:
            samples = audio_tokenizer.decode(
                [(encoded_frames.transpose(2, 1), None)]
            )
            samples = audio_tokenizer.decode([(encoded_frames.transpose(2, 1), None)])
            # store
            torchaudio.save(
                f"{args.output_dir}/{n}.wav", samples[0].cpu(), 24000
            )
            torchaudio.save(f"{args.output_dir}/{n}.wav", samples[0].cpu(), 24000)
        else:  # Transformer
            pass
@ -297,8 +280,6 @@ torch._C._jit_set_profiling_executor(False)
torch._C._jit_set_profiling_mode(False)
torch._C._set_graph_executor_optimize(False)
if __name__ == "__main__":
    formatter = (
        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
    )
    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
    logging.basicConfig(format=formatter, level=logging.INFO)
    main()
@ -3,9 +3,9 @@ from typing import List, Tuple

import numpy as np
import torch

from k2 import SymbolTable


class TextTokenCollater:
    """Collate list of text tokens

@ -52,15 +52,10 @@ class TextTokenCollater:
        self.token2idx = {token: idx for idx, token in enumerate(unique_tokens)}
        self.idx2token = [token for token in unique_tokens]

    def index(
        self, tokens_list: List[str]
    ) -> Tuple[torch.Tensor, torch.Tensor]:
    def index(self, tokens_list: List[str]) -> Tuple[torch.Tensor, torch.Tensor]:
        seqs, seq_lens = [], []
        for tokens in tokens_list:
            assert (
                all([True if s in self.token2idx else False for s in tokens])
                is True
            )
            assert all([True if s in self.token2idx else False for s in tokens]) is True
            seq = (
                ([self.bos_symbol] if self.add_bos else [])
                + list(tokens)
@ -103,10 +98,7 @@ class TextTokenCollater:
        )

        tokens_lens = torch.IntTensor(
            [
                len(seq) + int(self.add_eos) + int(self.add_bos)
                for seq in tokens_seqs
            ]
            [len(seq) + int(self.add_eos) + int(self.add_bos) for seq in tokens_seqs]
        )

        return tokens_batch, tokens_lens
@ -115,7 +107,5 @@ class TextTokenCollater:

def get_text_token_collater(text_tokens_file: str) -> TextTokenCollater:
    text_tokens_path = Path(text_tokens_file)
    unique_tokens = SymbolTable.from_file(text_tokens_path)
    collater = TextTokenCollater(
        unique_tokens.symbols, add_bos=True, add_eos=True
    )
    collater = TextTokenCollater(unique_tokens.symbols, add_bos=True, add_eos=True)
    return collater
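
For reference, a hedged usage sketch of the collater defined above; the symbol-table path comes from the data preparation stage, while the TextTokenizer backend and the input sentence are illustrative assumptions:

# Sketch only: phonemize a sentence and collate it into padded id tensors.
from compute_neural_codec_and_prepare_text_tokens import TextTokenizer, tokenize_text
from tokenizer import get_text_token_collater

text_tokenizer = TextTokenizer(backend="espeak")  # backend name assumed
text_collater = get_text_token_collater("data/tokenized/unique_text_tokens.k2symbols")
tokens, tokens_lens = text_collater([tokenize_text(text_tokenizer, text="hello world")])
# tokens: (1, T) ids including BOS/EOS; tokens_lens: per-sequence lengths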
@ -49,10 +49,9 @@ import argparse
import copy
import logging
import os
from contextlib import nullcontext

import random
import warnings
from contextlib import nullcontext
from pathlib import Path
from shutil import copyfile
from typing import Any, Dict, Optional, Tuple, Union
@ -60,6 +59,19 @@ from typing import Any, Dict, Optional, Tuple, Union
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from lhotse import CutSet
from lhotse.cut import Cut
from lhotse.dataset.sampling.base import CutSampler
from lhotse.utils import fix_random_seed
from optim import Eden, ScaledAdam
from tokenizer import TextTokenCollater, get_text_token_collater
from torch import Tensor
from torch.cuda.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter
from tts_datamodule import TtsDataModule
from valle import VALLE

from icefall.checkpoint import load_checkpoint, remove_checkpoints
from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
from icefall.checkpoint import (
@ -70,22 +82,10 @@ from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall.hooks import register_inf_check_hooks
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
from lhotse import CutSet
from lhotse.cut import Cut
from lhotse.dataset.sampling.base import CutSampler
from lhotse.utils import fix_random_seed
from torch import Tensor
from torch.cuda.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter

from tts_datamodule import TtsDataModule
from optim import Eden, ScaledAdam
from valle import VALLE
from tokenizer import TextTokenCollater, get_text_token_collater

LRSchedulerType = torch.optim.lr_scheduler._LRScheduler
def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None:
    if isinstance(model, DDP):
        # get underlying nn.Module
@ -95,6 +95,7 @@ def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None:
        if hasattr(module, "batch_count"):
            module.batch_count = batch_count


def add_model_arguments(parser: argparse.ArgumentParser):
    parser.add_argument(
        "--decoder-dim",
@ -159,6 +160,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
        help="Number of Audio/Semantic quantization layers.",
    )


def get_parser():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
@ -363,7 +365,6 @@ def get_parser():
    return parser


def get_params() -> AttributeDict:
    """Return a dict containing training parameters.

@ -568,9 +569,10 @@ def save_checkpoint(
        best_valid_filename = params.exp_dir / "best-valid-loss.pt"
        copyfile(src=filename, dst=best_valid_filename)

def prepare_input(batch: dict, tokenizer: Tokenizer, device: torch.device):

def prepare_input(batch: dict, tokenizer: TextTokenCollater, device: torch.device):
    """Parse batch data"""

    features = batch["features"].to(device)
    features_lens = batch["features_lens"].to(device)
    if "tokens" not in batch:
@ -590,6 +592,7 @@ def prepare_input(batch: dict, tokenizer: Tokenizer, device: torch.device):

    return features, features_lens, text_tokens, text_tokens_lens

def compute_loss(
    params: AttributeDict,
    model: Union[nn.Module, DDP],
@ -615,11 +618,7 @@ def compute_loss(
      warmup: a floating point value which increases throughout training;
        values >= 1.0 are fully warmed up and have all modules present.
    """
    device = (
        model.device
        if isinstance(model, DDP)
        else next(model.parameters()).device
    )
    device = model.device if isinstance(model, DDP) else next(model.parameters()).device
    (
        audio_features,
        audio_features_lens,
@ -684,9 +683,7 @@ def compute_validation_loss(
        params.best_valid_loss = loss_value

    if params.visualize:
        output_dir = Path(
            f"{params.exp_dir}/eval/step-{params.batch_idx_train:06d}"
        )
        output_dir = Path(f"{params.exp_dir}/eval/step-{params.batch_idx_train:06d}")
        output_dir.mkdir(parents=True, exist_ok=True)
        if isinstance(model, DDP):
            model.module.visualize(predicts, batch, output_dir=output_dir)
@ -777,26 +774,21 @@ def train_one_epoch(
            is_training=True,
        )
        # summary stats
        tot_loss = (
            tot_loss * (1 - 1 / params.reset_interval)
        ) + loss_info * (1 / params.reset_interval)
        tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info * (
            1 / params.reset_interval
        )

        # NOTE: We use reduction==sum and loss is computed over utterances
        # in the batch and there is no normalization to it so far.

        scaler.scale(loss).backward()
        if params.batch_idx_train >= params.accumulate_grad_steps:
            if (
                params.batch_idx_train % params.accumulate_grad_steps
                == 0
            ):
            if params.batch_idx_train % params.accumulate_grad_steps == 0:
                if params.optimizer_name not in ["ScaledAdam", "Eve"]:
                    # Unscales the gradients of optimizer's assigned params in-place
                    scaler.unscale_(optimizer)
                    # Since the gradients of optimizer's assigned params are unscaled, clips as usual:
                    torch.nn.utils.clip_grad_norm_(
                        model.parameters(), 1.0
                    )
                    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

                scaler.step(optimizer)
                scaler.update()
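
The pattern above (scaled backward on every batch, an optimizer step only every accumulate_grad_steps batches, unscale-then-clip for stock optimizers) can be summarized in isolation; a self-contained sketch with made-up model, data, and step counts:

# Sketch only: AMP training with gradient accumulation, mirroring the
# scaler.scale(...).backward() / scaler.step() / scaler.update() flow above.
import torch

model = torch.nn.Linear(8, 8).cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
scaler = torch.cuda.amp.GradScaler(enabled=True, init_scale=1.0)
accumulate_grad_steps = 4  # assumed value

for step in range(1, 101):
    x = torch.randn(16, 8, device="cuda")
    with torch.cuda.amp.autocast():
        loss = model(x).pow(2).sum()
    scaler.scale(loss).backward()  # gradients accumulate across steps
    if step % accumulate_grad_steps == 0:
        scaler.unscale_(optimizer)  # unscale before clipping
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        scaler.step(optimizer)  # skipped automatically if grads overflowed
        scaler.update()
        optimizer.zero_grad()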
@ -825,7 +817,7 @@ def train_one_epoch(
                    model_cur=model,
                    model_avg=model_avg,
                )

        if (
            params.batch_idx_train > 0
            and params.batch_idx_train % params.save_every_n == 0
@ -849,15 +841,13 @@ def train_one_epoch(
                topk=params.keep_last_k,
                rank=rank,
            )

        if batch_idx % 100 == 0 and params.dtype in ["float16", "fp16"]:
            # If the grad scale was less than 1, try increasing it. The _growth_interval
            # of the grad scaler is configurable, but we can't configure it to have different
            # behavior depending on the current grad scale.
            cur_grad_scale = scaler._scale.item()
            if cur_grad_scale < 1.0 or (
                cur_grad_scale < 8.0 and batch_idx % 400 == 0
            ):
            if cur_grad_scale < 1.0 or (cur_grad_scale < 8.0 and batch_idx % 400 == 0):
                scaler.update(cur_grad_scale * 2.0)

            if cur_grad_scale < 0.01:
@ -870,9 +860,7 @@ def train_one_epoch(
        if batch_idx % params.log_interval == 0:
            cur_lr = scheduler.get_last_lr()[0]
            cur_grad_scale = (
                scaler._scale.item()
                if params.dtype in ["float16", "fp16"]
                else 1.0
                scaler._scale.item() if params.dtype in ["float16", "fp16"] else 1.0
            )

            logging.info(
@ -897,12 +885,8 @@ def train_one_epoch(
                    "train/current_",
                    params.batch_idx_train,
                )
                tot_loss.write_summary(
                    tb_writer, "train/tot_", params.batch_idx_train
                )
                tot_loss.write_summary(
                    tb_writer, "train/tot_", params.batch_idx_train
                )
                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
                if params.dtype in ["float16", "fp16"]:
                    tb_writer.add_scalar(
                        "train/grad_scale",
@ -922,9 +906,7 @@ def train_one_epoch(
                valid_dl=valid_dl,
                world_size=world_size,
            )
            logging.info(
                f"Epoch {params.cur_epoch}, validation: {valid_info}"
            )
            logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
            logging.info(
                f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
            )
@ -1063,10 +1045,7 @@ def run(rank, world_size, args):
            )
        else:
            parameters_names.append(
                [
                    name_param_pair[0]
                    for name_param_pair in model.named_parameters()
                ]
                [name_param_pair[0] for name_param_pair in model.named_parameters()]
            )

        optimizer = ScaledAdam(
@ -1144,9 +1123,7 @@ def run(rank, world_size, args):
        params=params,
    )

    scaler = GradScaler(
        enabled=(params.dtype in ["fp16", "float16"]), init_scale=1.0
    )
    scaler = GradScaler(enabled=(params.dtype in ["fp16", "float16"]), init_scale=1.0)
    if checkpoints and "grad_scaler" in checkpoints:
        logging.info("Loading grad scaler state dict")
        scaler.load_state_dict(checkpoints["grad_scaler"])
@ -52,6 +52,7 @@ class _SeedWorkers:
    def __call__(self, worker_id: int):
        fix_random_seed(self.seed + worker_id)


class TtsDataModule:
    """
    DataModule for tts experiments.
@ -301,9 +302,7 @@ class TtsDataModule:
    @lru_cache()
    def train_cuts(self) -> CutSet:
        logging.info("About to get train cuts")
        return load_manifest_lazy(
            self.args.manifest_dir / "cuts_train.jsonl.gz"
        )
        return load_manifest_lazy(self.args.manifest_dir / "cuts_train.jsonl.gz")

    @lru_cache()
    def dev_cuts(self) -> CutSet:
@ -341,4 +340,4 @@ class TtsDataModule:
        logging.info("About to get test-other cuts")
        return load_manifest_lazy(
            self.args.manifest_dir / "libritts_cuts_test-other.jsonl.gz"
        )
        )
@ -12,36 +12,36 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import math
import numbers
import random
from typing import Dict, Iterator, List, Tuple, Union
from functools import partial
from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from icefall.utils import make_pad_mask
from torch import Tensor
from torch.nn import Linear, Module
from torch.nn import functional as F
from torch.nn.init import constant_, xavier_normal_, xavier_uniform_
from torch.nn.modules.linear import NonDynamicallyQuantizableLinear
from torch.nn.parameter import Parameter
from torchmetrics.classification import MulticlassAccuracy

# from valle.data.input_strategies import PromptedFeatures
# from valle.modules.embedding import SinePositionalEmbedding, TokenEmbedding
# from valle.modules.transformer import (
#     AdaptiveLayerNorm,
#     LayerNorm,
#     TransformerEncoder,
#     TransformerEncoderLayer,
# )
from icefall.utils import make_pad_mask

from .macros import NUM_AUDIO_TOKENS, NUM_TEXT_TOKENS
from .visualizer import visualize


class PromptedFeatures:
    def __init__(self, prompts, features):
        self.prompts = prompts
        self.features = features

    def to(self, device):
        return PromptedFeatures(
            self.prompts.to(device), self.features.to(device)
        )
        return PromptedFeatures(self.prompts.to(device), self.features.to(device))

    def sum(self):
        return self.features.sum()
@ -54,6 +54,7 @@ class PromptedFeatures:
    def data(self):
        return (self.prompts, self.features)


class TokenEmbedding(nn.Module):
    def __init__(
        self,
@ -114,9 +115,7 @@ class SinePositionalEmbedding(nn.Module):
                x.size(1) - 1, -1, -1.0, dtype=torch.float32
            ).unsqueeze(1)
        else:
            position = torch.arange(
                0, x.size(1), dtype=torch.float32
            ).unsqueeze(1)
            position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, self.dim_model, 2, dtype=torch.float32)
            * -(math.log(10000.0) / self.dim_model)
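
The position/div_term pair above is the standard sinusoidal positional encoding; a small standalone sketch with illustrative dimensions:

# Sketch only: sinusoidal positional-embedding table matching the construction above.
import math
import torch

dim_model, max_len = 16, 10
position = torch.arange(0, max_len, dtype=torch.float32).unsqueeze(1)  # (max_len, 1)
div_term = torch.exp(
    torch.arange(0, dim_model, 2, dtype=torch.float32)
    * -(math.log(10000.0) / dim_model)
)  # (dim_model / 2,)
pe = torch.zeros(max_len, dim_model)
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)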
@ -132,14 +131,17 @@ class SinePositionalEmbedding(nn.Module):
        output = output * self.x_scale + self.alpha * self.pe[:, : x.size(1)]
        return self.dropout(output)


class Transpose(nn.Identity):
    """(N, T, D) -> (N, D, T)"""

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return input.transpose(1, 2)


_shape_t = Union[int, List[int], torch.Size]


class MultiheadAttention(Module):
    r"""Allows the model to jointly attend to information
    from different representation subspaces as described in the paper:
@ -221,9 +223,7 @@ class MultiheadAttention(Module):
        self.embed_dim = embed_dim
        self.kdim = kdim if kdim is not None else embed_dim
        self.vdim = vdim if vdim is not None else embed_dim
        self._qkv_same_embed_dim = (
            self.kdim == embed_dim and self.vdim == embed_dim
        )
        self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim

        self.num_heads = num_heads
        self.dropout = dropout
@ -234,12 +234,8 @@ class MultiheadAttention(Module):
        ), "embed_dim must be divisible by num_heads"

        if add_bias_kv:
            self.bias_k = Parameter(
                torch.empty((1, 1, embed_dim), **factory_kwargs)
            )
            self.bias_v = Parameter(
                torch.empty((1, 1, embed_dim), **factory_kwargs)
            )
            self.bias_k = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs))
            self.bias_v = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs))
        else:
            self.bias_k = self.bias_v = None
@ -396,20 +392,18 @@ class MultiheadAttention(Module):
            )
        why_not_fast_path = ""
        if not is_batched:
            why_not_fast_path = f"input not batched; expected query.dim() of 3 but got {query.dim()}"
            why_not_fast_path = (
                f"input not batched; expected query.dim() of 3 but got {query.dim()}"
            )
        elif query is not key or key is not value:
            # When lifting this restriction, don't forget to either
            # enforce that the dtypes all match or test cases where
            # they don't!
            why_not_fast_path = "non-self attention was used (query, key, and value are not the same Tensor)"
        elif (
            self.in_proj_bias is not None
            and query.dtype != self.in_proj_bias.dtype
        ):
        elif self.in_proj_bias is not None and query.dtype != self.in_proj_bias.dtype:
            why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_bias ({self.in_proj_bias.dtype}) don't match"
        elif (
            self.in_proj_weight is not None
            and query.dtype != self.in_proj_weight.dtype
            self.in_proj_weight is not None and query.dtype != self.in_proj_weight.dtype
        ):
            # this case will fail anyway, but at least they'll get a useful error message.
            why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_weight ({self.in_proj_weight.dtype}) don't match"
@ -458,9 +452,7 @@ class MultiheadAttention(Module):
                    for x in tensor_args
                ]
            ):
                why_not_fast_path = (
                    "some Tensor argument is neither CUDA nor CPU"
                )
                why_not_fast_path = "some Tensor argument is neither CUDA nor CPU"
            elif torch.is_grad_enabled() and any(
                [x is not None and x.requires_grad for x in tensor_args]
            ):
@ -479,9 +471,7 @@ class MultiheadAttention(Module):
                    self.in_proj_bias,
                    self.out_proj.weight,
                    self.out_proj.bias,
                    key_padding_mask
                    if key_padding_mask is not None
                    else attn_mask,
                    key_padding_mask if key_padding_mask is not None else attn_mask,
                    need_weights,
                    average_attn_weights,
                    1
@ -506,9 +496,7 @@ class MultiheadAttention(Module):
            query, key = [x.transpose(1, 0) for x in (query, key)]
            value = key
        else:
            query, key, value = [
                x.transpose(1, 0) for x in (query, key, value)
            ]
            query, key, value = [x.transpose(1, 0) for x in (query, key, value)]

        if not self._qkv_same_embed_dim:
            attn_output, attn_output_weights = F.multi_head_attention_forward(
@ -722,14 +710,11 @@ class TransformerEncoderLayer(nn.Module):
            self.activation = activation

        norm1 = layer_norm_cls(d_model, eps=layer_norm_eps, **factory_kwargs)
        if layer_norm_cls == IdentityNorm:
            norm2 = BalancedBasicNorm(
                d_model, eps=layer_norm_eps, **factory_kwargs
            )
        else:
            norm2 = layer_norm_cls(
                d_model, eps=layer_norm_eps, **factory_kwargs
            )
        # if layer_norm_cls == IdentityNorm:
        #     norm2 = BalancedBasicNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
        # else:
        if True:
            norm2 = layer_norm_cls(d_model, eps=layer_norm_eps, **factory_kwargs)

        if adaptive_layer_norm:
            self.norm1 = AdaptiveLayerNorm(d_model, norm1)
@ -887,7 +872,6 @@ class TransformerEncoder(nn.Module):
        return output


def _get_clones(module, N):
    return nn.ModuleList([copy.deepcopy(module) for i in range(N)])

@ -898,9 +882,8 @@ def _get_activation_fn(activation: str) -> Callable[[Tensor], Tensor]:
    elif activation == "gelu":
        return F.gelu

    raise RuntimeError(
        "activation should be relu/gelu, not {}".format(activation)
    )
    raise RuntimeError("activation should be relu/gelu, not {}".format(activation))


class VALLE(nn.Module):
    """It implements https://arxiv.org/abs/2301.02111
@ -1003,9 +986,7 @@ class VALLE(nn.Module):
            num_layers=num_layers,
            norm=LayerNorm(d_model) if norm_first else None,
        )
        self.ar_predict_layer = nn.Linear(
            d_model, NUM_AUDIO_TOKENS + 1, bias=False
        )
        self.ar_predict_layer = nn.Linear(d_model, NUM_AUDIO_TOKENS + 1, bias=False)

        self.ar_accuracy_metric = MulticlassAccuracy(
            NUM_AUDIO_TOKENS + 1,
@ -1034,21 +1015,15 @@ class VALLE(nn.Module):
        if add_prenet:
            self.nar_text_prenet = nn.Sequential(
                Transpose(),
                nn.Conv1d(
                    nar_d_model, nar_d_model, kernel_size=5, padding="same"
                ),
                nn.Conv1d(nar_d_model, nar_d_model, kernel_size=5, padding="same"),
                nn.BatchNorm1d(nar_d_model),
                nn.ReLU(),
                nn.Dropout(0.5),
                nn.Conv1d(
                    nar_d_model, nar_d_model, kernel_size=5, padding="same"
                ),
                nn.Conv1d(nar_d_model, nar_d_model, kernel_size=5, padding="same"),
                nn.BatchNorm1d(nar_d_model),
                nn.ReLU(),
                nn.Dropout(0.5),
                nn.Conv1d(
                    nar_d_model, nar_d_model, kernel_size=5, padding="same"
                ),
                nn.Conv1d(nar_d_model, nar_d_model, kernel_size=5, padding="same"),
                nn.BatchNorm1d(nar_d_model),
                nn.ReLU(),
                nn.Dropout(0.5),
@ -1092,9 +1067,7 @@ class VALLE(nn.Module):
                    adaptive_layer_norm=True,
                ),
                num_layers=int(num_layers * nar_scale_factor),
                norm=AdaptiveLayerNorm(
                    nar_d_model, norm=nn.LayerNorm(nar_d_model)
                )
                norm=AdaptiveLayerNorm(nar_d_model, norm=nn.LayerNorm(nar_d_model))
                if norm_first
                else None,
            )
@ -1105,10 +1078,7 @@ class VALLE(nn.Module):
                ]
            )
            self.nar_stage_embeddings = nn.ModuleList(
                [
                    TokenEmbedding(nar_d_model, 1)
                    for i in range(num_quantizers - 1)
                ]
                [TokenEmbedding(nar_d_model, 1) for i in range(num_quantizers - 1)]
            )

            if share_embedding:
@ -1119,9 +1089,9 @@ class VALLE(nn.Module):
                # We also share the parameters of the acoustic embedding layer and the output prediction layer,
                # which means the weights of the j-th prediction layer are the same as the (j + 1)-th acoustic embedding layer.
                for j in range(0, num_quantizers - 2):
                    self.nar_predict_layers[
                        j
                    ].weight = self.nar_audio_embeddings[j + 2].weight
                    self.nar_predict_layers[j].weight = self.nar_audio_embeddings[
                        j + 2
                    ].weight

        self.nar_accuracy_metric = MulticlassAccuracy(
            NUM_AUDIO_TOKENS + 1,
@ -1192,13 +1162,9 @@ class VALLE(nn.Module):
            y_prompts = self.nar_audio_embeddings[0](y[:, :prefix_len])
            y_emb = self.nar_audio_embeddings[0](y[:, prefix_len:])
            for j in range(1, self.num_quantizers):
                y_prompts += self.nar_audio_embeddings[j](
                    codes[:, :prefix_len, j]
                )
                y_prompts += self.nar_audio_embeddings[j](codes[:, :prefix_len, j])
                if j < nar_stage:
                    y_emb += self.nar_audio_embeddings[j](
                        codes[:, prefix_len:, j]
                    )
                    y_emb += self.nar_audio_embeddings[j](codes[:, prefix_len:, j])
            y_emb = torch.concat([y_prompts, y_emb], axis=1)
        elif self.prefix_mode in [2, 4]:
            if self.prefix_mode == 2:
@ -1211,9 +1177,7 @@ class VALLE(nn.Module):
                    y_prompts_codes.append(
                        torch.clone(codes[b, start : start + prefix_len])
                    )
                    codes[
                        b, start : start + prefix_len, nar_stage
                    ] = NUM_AUDIO_TOKENS
                    codes[b, start : start + prefix_len, nar_stage] = NUM_AUDIO_TOKENS
                y_prompts_codes = torch.stack(y_prompts_codes, dim=0)
            else:
                prefix_len = y_prompts_codes.shape[1]
@ -1221,9 +1185,7 @@ class VALLE(nn.Module):
            y_prompts = self.nar_audio_embeddings[0](y_prompts_codes[..., 0])
            y_emb = self.nar_audio_embeddings[0](y)
            for j in range(1, self.num_quantizers):
                y_prompts += self.nar_audio_embeddings[j](
                    y_prompts_codes[..., j]
                )
                y_prompts += self.nar_audio_embeddings[j](y_prompts_codes[..., j])
                if j < nar_stage:
                    y_emb += self.nar_audio_embeddings[j](codes[..., j])
            y_emb = torch.concat([y_prompts, y_emb], axis=1)
@ -1290,9 +1252,7 @@ class VALLE(nn.Module):
        text = x
        codes = y.type(torch.int64) * (1 - y_mask_int.unsqueeze(dim=-1))

        y, targets = self.pad_y_eos(
            codes[..., 0], y_mask_int, eos_id=NUM_AUDIO_TOKENS
        )
        y, targets = self.pad_y_eos(codes[..., 0], y_mask_int, eos_id=NUM_AUDIO_TOKENS)

        x_len = x_lens.max()

@ -1408,21 +1368,16 @@ class VALLE(nn.Module):
            xy_dec = xy_dec[:, x_lens.max() + prefix_len :]
            if self.prefix_mode == 4:
                prefix_len = 0  # reset for Top10Accuracy metric
            logits = self.nar_predict_layers[nar_stage - 1](xy_dec).permute(
                0, 2, 1
            )
            logits = self.nar_predict_layers[nar_stage - 1](xy_dec).permute(0, 2, 1)

            # loss
            total_length = (y_lens).sum().type(torch.float32)
            total_loss += (
                F.cross_entropy(
                    logits,
                    targets,
                    ignore_index=NUM_AUDIO_TOKENS,
                    reduction=reduction,
                )
                * (total_length / (total_length - prefix_len * x.shape[0]))
            )
            total_loss += F.cross_entropy(
                logits,
                targets,
                ignore_index=NUM_AUDIO_TOKENS,
                reduction=reduction,
            ) * (total_length / (total_length - prefix_len * x.shape[0]))
            metrics["NarTop10Accuracy"] = (
                self.nar_accuracy_metric(
                    F.pad(
@ -1505,24 +1460,27 @@ class VALLE(nn.Module):
                value=True,
            )
            y_attn_mask = F.pad(
                torch.triu(
                    torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1
                ),
                torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1),
                (x_len, 0),
                value=False,
            )
            xy_attn_mask = torch.concat(
                [x_attn_mask_pad, y_attn_mask], dim=0
            ).to(y.device)
            xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0).to(
                y.device
            )

            xy_dec, _ = self.ar_decoder(
                (xy_pos, None),
                mask=xy_attn_mask,
            )
            logits = self.ar_predict_layer(xy_dec[:, -1])
            ras=True
            ras = True
            samples = topk_sampling(
                logits, top_k=top_k, top_p=top_p, temperature=temperature, repetition_aware_sampling=ras, preceding_tokens=y
                logits,
                top_k=top_k,
                top_p=top_p,
                temperature=temperature,
                repetition_aware_sampling=ras,
                preceding_tokens=y,
            )

            if (
@ -1531,9 +1489,7 @@ class VALLE(nn.Module):
                or (y.shape[1] - prompts.shape[1]) > x_lens.max() * 16
            ):
                if prompts.shape[1] == y.shape[1]:
                    raise SyntaxError(
                        "well trained model shouldn't reach here."
                    )
                    raise SyntaxError("well trained model shouldn't reach here.")

                print(f"VALL-E EOS [{prompts.shape[1]} -> {y.shape[1]}]")
                break
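
The mask built above lets every position attend to the whole text prefix while audio positions attend causally; a tiny standalone sketch of the same combined mask with illustrative sizes (True means "masked"):

# Sketch only: combined text/audio attention mask as constructed above.
import torch
import torch.nn.functional as F

x_len, y_len = 3, 4
x_attn_mask_pad = F.pad(
    torch.zeros((x_len, x_len), dtype=torch.bool),  # text rows see all text
    (0, y_len),
    value=True,  # ...but are blocked from audio positions
)
y_attn_mask = F.pad(
    torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1),  # causal over audio
    (x_len, 0),
    value=False,  # audio rows see all text
)
xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0)  # (x_len + y_len, x_len + y_len)
print(xy_attn_mask)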
@ -1545,9 +1501,7 @@ class VALLE(nn.Module):
            return torch.stack(codes, dim=-1)

        # Non-AR Decoders
        y_emb = self.nar_audio_embeddings[0](
            y[:, int(self.ar_audio_prepend_bos) :]
        )
        y_emb = self.nar_audio_embeddings[0](y[:, int(self.ar_audio_prepend_bos) :])

        if self.prefix_mode in [2, 4]:  # Exclude enrolled_phonemes
            enrolled_len = enroll_x_lens.max().item()
@ -1586,15 +1540,11 @@ class VALLE(nn.Module):
                codes.append(samples)

                if i < self.num_quantizers - 2:
                    y_emb[:, :prefix_len] += embedding_layer(
                        prompts[..., i + 1]
                    )
                    y_emb[:, :prefix_len] += embedding_layer(prompts[..., i + 1])
                    y_emb[:, prefix_len:] += embedding_layer(samples)
        else:
            for j in range(1, self.num_quantizers):
                y_emb[:, :prefix_len] += self.nar_audio_embeddings[j](
                    prompts[..., j]
                )
                y_emb[:, :prefix_len] += self.nar_audio_embeddings[j](prompts[..., j])

            for i, (predict_layer, embedding_layer) in enumerate(
                zip(
@ -1687,15 +1637,11 @@ class VALLE(nn.Module):
                codes.append(samples)

                if i < 6:
                    y_emb[:, :prefix_len] += embedding_layer(
                        prompts[..., i + 1]
                    )
                    y_emb[:, :prefix_len] += embedding_layer(prompts[..., i + 1])
                    y_emb[:, prefix_len:] += embedding_layer(samples)
        else:
            for j in range(1, 8):
                y_emb[:, :prefix_len] += self.nar_audio_embeddings[j](
                    prompts[..., j]
                )
                y_emb[:, :prefix_len] += self.nar_audio_embeddings[j](prompts[..., j])

            for i, (predict_layer, embedding_layer) in enumerate(
                zip(
@ -1736,18 +1682,14 @@ def top_k_top_p_filtering(
        From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
    """
    if top_k > 0:
        top_k = min(
            max(top_k, min_tokens_to_keep), logits.size(-1)
        )  # Safety check
        top_k = min(max(top_k, min_tokens_to_keep), logits.size(-1))  # Safety check
        # Remove all tokens with a probability less than the last token of the top-k
        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
        logits[indices_to_remove] = filter_value

    if top_p < 1.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(
            F.softmax(sorted_logits, dim=-1), dim=-1
        )
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

        # Remove tokens with cumulative probability above the threshold (token with 0 are kept)
        sorted_indices_to_remove = cumulative_probs > top_p
@ -1755,9 +1697,7 @@ def top_k_top_p_filtering(
        # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
        sorted_indices_to_remove[..., :min_tokens_to_keep] = 0
        # Shift the indices to the right to keep also the first token above the threshold
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[
            ..., :-1
        ].clone()
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0

        # scatter sorted tensors to original indexing
@ -1768,7 +1708,14 @@ def top_k_top_p_filtering(
    return logits

def topk_sampling(logits, top_k=10, top_p=1.0, temperature=1.0, repetition_aware_sampling=False, preceding_tokens=None):
def topk_sampling(
    logits,
    top_k=10,
    top_p=1.0,
    temperature=1.0,
    repetition_aware_sampling=False,
    preceding_tokens=None,
):
    # temperature: (`optional`) float
    #     The value used to module the next token probabilities. Must be strictly positive. Default to 1.0.
    # top_k: (`optional`) int
@ -1780,11 +1727,13 @@ def topk_sampling(logits, top_k=10, top_p=1.0, temperature=1.0, repetition_aware
    if temperature != 1.0:
        logits = logits / temperature
    # Top-p/top-k filtering
    logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p, min_tokens_to_keep=2)
    logits = top_k_top_p_filtering(
        logits, top_k=top_k, top_p=top_p, min_tokens_to_keep=2
    )
    # Sample
    probs = F.softmax(logits, dim=-1)
    # print top 10 value and index
    print("top 10 value and index", torch.topk(probs, 10), top_p)
    print("top 10 value and index", torch.topk(probs, 10), top_p)
    tokens = torch.multinomial(probs, num_samples=1)

    if repetition_aware_sampling:
@ -1814,7 +1763,9 @@ def topk_sampling(logits, top_k=10, top_p=1.0, temperature=1.0, repetition_aware
            probs = F.softmax(logits[i], dim=-1)
            token_new = torch.multinomial(probs, num_samples=1)

            print(f"Repetition Aware Sampling: {item}, {tokens[i]} -> {token_new}")
            print(
                f"Repetition Aware Sampling: {item}, {tokens[i]} -> {token_new}"
            )
            print("probs", probs, logits.shape)
            tokens[i] = token_new
        else:
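
The repetition-aware branch is cut off at this point. As a rough, non-authoritative sketch of the idea the visible lines implement — if the token drawn after top-k/top-p filtering repeats too often among the most recent decoded tokens, draw it again — with the window size and threshold as assumptions:

# Sketch only: repetition aware sampling fallback; window/threshold values are assumed.
import torch
import torch.nn.functional as F

def ras_resample(filtered_logits, sampled_tokens, preceding_tokens, window=10, threshold=0.1):
    # filtered_logits: (B, V) logits after top-k/top-p filtering
    # sampled_tokens: (B, 1) tokens drawn from those logits
    # preceding_tokens: (B, T) previously decoded tokens
    for i in range(sampled_tokens.size(0)):
        recent = preceding_tokens[i, -window:]
        repeat_ratio = (recent == sampled_tokens[i]).float().mean().item()
        if repeat_ratio > threshold:
            probs = F.softmax(filtered_logits[i], dim=-1)
            sampled_tokens[i] = torch.multinomial(probs, num_samples=1)
    return sampled_tokens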