add infer code

This commit is contained in:
root 2024-11-19 08:12:54 +00:00
parent 5361ecdc56
commit d55a534af8
15 changed files with 1029 additions and 855 deletions

View File

@ -1,575 +0,0 @@
#!/usr/bin/env python3
# Copyright 2023 (authors: Feiteng Li)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Phonemize Text and EnCodec Audio.
Usage example:
python3 bin/tokenizer.py \
--src_dir ./data/manifests --output_dir ./data/tokenized
"""
import argparse
import logging
import os
from pathlib import Path
import torch
import torch.multiprocessing
from icefall.utils import get_executor
from lhotse import CutSet, NumpyHdf5Writer
from lhotse.recipes.utils import read_manifests_if_cached
from tqdm.auto import tqdm
from valle.data import (
AudioTokenConfig,
AudioTokenExtractor,
TextTokenizer,
tokenize_text,
)
# from valle.data.fbank import get_fbank_extractor
from valle.utils import SymbolTable
os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"
# Torch's multithreaded behavior needs to be disabled or
# it wastes a lot of CPU and slow things down.
# Do this outside of main() in case it needs to take effect
# even when we are not invoking the main (e.g. when spawning subprocesses).
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
torch.multiprocessing.set_sharing_strategy("file_system")
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--src-dir",
type=Path,
default=Path("data/manifests"),
help="Path to the manifest files",
)
parser.add_argument(
"--output-dir",
type=Path,
default=Path("data/tokenized"),
help="Path to the tokenized files",
)
parser.add_argument(
"--text-extractor",
type=str,
default="espeak",
help="espeak or pypinyin or pypinyin_initials_finals",
)
parser.add_argument(
"--audio-extractor",
type=str,
default="Encodec",
help="Encodec or Fbank",
)
parser.add_argument(
"--dataset-parts",
type=str,
default="dev-clean test-clean",
help="Space separated dataset parts",
)
parser.add_argument(
"--prefix",
type=str,
default="libritts",
help="prefix of the manifest file",
)
parser.add_argument(
"--suffix",
type=str,
default="jsonl.gz",
help="suffix of the manifest file",
)
parser.add_argument(
"--batch-duration",
type=float,
default=400.0,
help="The maximum number of audio seconds in a batch."
"Determines batch size dynamically.",
)
parser.add_argument(
"--split",
type=int,
default=1,
help="Split the cut_set into multiple parts",
)
return parser.parse_args()
class PypinyinBackend:
"""PypinyinBackend for Chinese. Most codes is referenced from espnet.
There are two types pinyin or initials_finals, one is
just like "ni1 hao3", the other is like "n i1 h ao3".
"""
def __init__(
self,
backend="initials_finals",
punctuation_marks: Union[str, Pattern] = Punctuation.default_marks(),
) -> None:
self.backend = backend
self.punctuation_marks = punctuation_marks
def phonemize(
self, text: List[str], separator: Separator, strip=True, njobs=1
) -> List[str]:
assert isinstance(text, List)
phonemized = []
for _text in text:
_text = re.sub(" +", " ", _text.strip())
_text = _text.replace(" ", separator.word)
phones = []
if self.backend == "pypinyin":
for n, py in enumerate(
pinyin(
_text, style=Style.TONE3, neutral_tone_with_five=True
)
):
if all([c in self.punctuation_marks for c in py[0]]):
if len(phones):
assert phones[-1] == separator.syllable
phones.pop(-1)
phones.extend(list(py[0]))
else:
phones.extend([py[0], separator.syllable])
elif self.backend == "pypinyin_initials_finals":
for n, py in enumerate(
pinyin(
_text, style=Style.TONE3, neutral_tone_with_five=True
)
):
if all([c in self.punctuation_marks for c in py[0]]):
if len(phones):
assert phones[-1] == separator.syllable
phones.pop(-1)
phones.extend(list(py[0]))
else:
if py[0][-1].isalnum():
initial = get_initials(py[0], strict=False)
if py[0][-1].isdigit():
final = (
get_finals(py[0][:-1], strict=False)
+ py[0][-1]
)
else:
final = get_finals(py[0], strict=False)
phones.extend(
[
initial,
separator.phone,
final,
separator.syllable,
]
)
else:
assert ValueError
else:
raise NotImplementedError
phonemized.append(
"".join(phones).rstrip(f"{separator.word}{separator.syllable}")
)
return phonemized
class TextTokenizer:
"""Phonemize Text."""
def __init__(
self,
language="en-us",
backend="espeak",
separator=Separator(word="_", syllable="-", phone="|"),
preserve_punctuation=True,
punctuation_marks: Union[str, Pattern] = Punctuation.default_marks(),
with_stress: bool = False,
tie: Union[bool, str] = False,
language_switch: LanguageSwitch = "keep-flags",
words_mismatch: WordMismatch = "ignore",
) -> None:
if backend == "espeak":
phonemizer = EspeakBackend(
language,
punctuation_marks=punctuation_marks,
preserve_punctuation=preserve_punctuation,
with_stress=with_stress,
tie=tie,
language_switch=language_switch,
words_mismatch=words_mismatch,
)
elif backend in ["pypinyin", "pypinyin_initials_finals"]:
phonemizer = PypinyinBackend(
backend=backend,
punctuation_marks=punctuation_marks + separator.word,
)
else:
raise NotImplementedError(f"{backend}")
self.backend = phonemizer
self.separator = separator
def to_list(self, phonemized: str) -> List[str]:
fields = []
for word in phonemized.split(self.separator.word):
# "ɐ m|iː|n?" ɹ|ɪ|z|ɜː|v; h|ɪ|z.
pp = re.findall(r"\w+|[^\w\s]", word, re.UNICODE)
fields.extend(
[p for p in pp if p != self.separator.phone]
+ [self.separator.word]
)
assert len("".join(fields[:-1])) == len(phonemized) - phonemized.count(
self.separator.phone
)
return fields[:-1]
def __call__(self, text, strip=True) -> List[List[str]]:
if isinstance(text, str):
text = [text]
phonemized = self.backend.phonemize(
text, separator=self.separator, strip=strip, njobs=1
)
return [self.to_list(p) for p in phonemized]
def tokenize_text(tokenizer: TextTokenizer, text: str) -> List[str]:
phonemes = tokenizer([text.strip()])
return phonemes[0] # k2symbols
def remove_encodec_weight_norm(model):
from encodec.modules import SConv1d
from encodec.modules.seanet import SConvTranspose1d, SEANetResnetBlock
from torch.nn.utils import remove_weight_norm
encoder = model.encoder.model
for key in encoder._modules:
if isinstance(encoder._modules[key], SEANetResnetBlock):
remove_weight_norm(encoder._modules[key].shortcut.conv.conv)
block_modules = encoder._modules[key].block._modules
for skey in block_modules:
if isinstance(block_modules[skey], SConv1d):
remove_weight_norm(block_modules[skey].conv.conv)
elif isinstance(encoder._modules[key], SConv1d):
remove_weight_norm(encoder._modules[key].conv.conv)
decoder = model.decoder.model
for key in decoder._modules:
if isinstance(decoder._modules[key], SEANetResnetBlock):
remove_weight_norm(decoder._modules[key].shortcut.conv.conv)
block_modules = decoder._modules[key].block._modules
for skey in block_modules:
if isinstance(block_modules[skey], SConv1d):
remove_weight_norm(block_modules[skey].conv.conv)
elif isinstance(decoder._modules[key], SConvTranspose1d):
remove_weight_norm(decoder._modules[key].convtr.convtr)
elif isinstance(decoder._modules[key], SConv1d):
remove_weight_norm(decoder._modules[key].conv.conv)
class AudioTokenizer:
"""EnCodec audio."""
def __init__(
self,
device: Any = None,
) -> None:
# Instantiate a pretrained EnCodec model
model = EncodecModel.encodec_model_24khz()
model.set_target_bandwidth(6.0)
remove_encodec_weight_norm(model)
if not device:
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda:0")
self._device = device
self.codec = model.to(device)
self.sample_rate = model.sample_rate
self.channels = model.channels
@property
def device(self):
return self._device
def encode(self, wav: torch.Tensor) -> torch.Tensor:
return self.codec.encode(wav.to(self.device))
def decode(self, frames: torch.Tensor) -> torch.Tensor:
return self.codec.decode(frames)
@dataclass
class AudioTokenConfig:
frame_shift: Seconds = 320.0 / 24000
num_quantizers: int = 8
def to_dict(self) -> Dict[str, Any]:
return asdict(self)
@staticmethod
def from_dict(data: Dict[str, Any]) -> "AudioTokenConfig":
return AudioTokenConfig(**data)
class AudioTokenExtractor(FeatureExtractor):
name = "encodec"
config_type = AudioTokenConfig
def __init__(self, config: Optional[Any] = None):
super(AudioTokenExtractor, self).__init__(config)
self.tokenizer = AudioTokenizer()
def extract(
self, samples: Union[np.ndarray, torch.Tensor], sampling_rate: int
) -> np.ndarray:
if not isinstance(samples, torch.Tensor):
samples = torch.from_numpy(samples)
if sampling_rate != self.tokenizer.sample_rate:
samples = convert_audio(
samples,
sampling_rate,
self.tokenizer.sample_rate,
self.tokenizer.channels,
)
if len(samples.shape) == 2:
samples = samples.unsqueeze(0)
else:
raise ValueError()
device = self.tokenizer.device
encoded_frames = self.tokenizer.encode(samples.detach().to(device))
codes = encoded_frames[0][0] # [B, n_q, T]
if True:
duration = round(samples.shape[-1] / sampling_rate, ndigits=12)
expected_num_frames = compute_num_frames(
duration=duration,
frame_shift=self.frame_shift,
sampling_rate=sampling_rate,
)
assert abs(codes.shape[-1] - expected_num_frames) <= 1
codes = codes[..., :expected_num_frames]
return codes.cpu().squeeze(0).permute(1, 0).numpy()
@property
def frame_shift(self) -> Seconds:
return self.config.frame_shift
def feature_dim(self, sampling_rate: int) -> int:
return self.config.num_quantizers
def pad_tensor_list(self, tensor_list, device, padding_value=0):
# 计算每个张量的长度
lengths = [tensor.shape[0] for tensor in tensor_list]
# 使用pad_sequence函数进行填充
tensor_list = [torch.Tensor(t).to(device) for t in tensor_list]
padded_tensor = torch.nn.utils.rnn.pad_sequence(
tensor_list, batch_first=True, padding_value=padding_value
)
return padded_tensor, lengths
def extract_batch(self, samples, sampling_rate, lengths) -> np.ndarray:
samples = [wav.squeeze() for wav in samples]
device = self.tokenizer.device
samples, lengths = self.pad_tensor_list(samples, device)
samples = samples.unsqueeze(1)
if not isinstance(samples, torch.Tensor):
samples = torch.from_numpy(samples)
if len(samples.shape) != 3:
raise ValueError()
if sampling_rate != self.tokenizer.sample_rate:
samples = [
convert_audio(
wav,
sampling_rate,
self.tokenizer.sample_rate,
self.tokenizer.channels,
)
for wav in samples
]
samples = torch.stack(samples, 0) # convert samples from list to tensor
# Extract discrete codes from EnCodec
with torch.no_grad():
encoded_frames = self.tokenizer.encode(samples.detach().to(device))
encoded_frames = encoded_frames[0][0] # [B, n_q, T]
batch_codes = []
for b, length in enumerate(lengths):
codes = encoded_frames[b]
duration = round(length / sampling_rate, ndigits=12)
expected_num_frames = compute_num_frames(
duration=duration,
frame_shift=self.frame_shift,
sampling_rate=sampling_rate,
)
batch_codes.append(codes[..., :expected_num_frames])
return [codes.cpu().permute(1, 0).numpy() for codes in batch_codes]
def main():
args = get_args()
dataset_parts = args.dataset_parts.replace("--dataset-parts", "").strip()
if dataset_parts == "all": # LibriTTS
dataset_parts = [
"dev-clean",
"dev-other",
"test-clean",
"test-other",
"train-clean-100",
"train-clean-360",
"train-other-500",
]
else:
dataset_parts = dataset_parts.replace("-p", "").strip().split(" ")
assert len(dataset_parts) >= 1
manifests = read_manifests_if_cached(
dataset_parts=dataset_parts,
output_dir=args.src_dir,
prefix=args.prefix,
suffix=args.suffix,
types=["recordings", "supervisions", "cuts"],
)
text_tokenizer = None
if args.text_extractor:
text_tokenizer = TextTokenizer(backend=args.text_extractor)
audio_extractor = None
if args.audio_extractor:
if args.audio_extractor == "Encodec":
audio_extractor = AudioTokenExtractor(AudioTokenConfig())
else:
assert args.audio_extractor == "Fbank"
audio_extractor = get_fbank_extractor()
Path(args.output_dir).mkdir(parents=True, exist_ok=True)
unique_symbols = set()
num_jobs = min(32, os.cpu_count())
logging.info(f"dataset_parts: {dataset_parts} manifests {len(manifests)}")
prefix = args.prefix
if prefix and not prefix.endswith("_"):
prefix = f"{prefix}_"
with get_executor() as ex:
for partition, m in manifests.items():
logging.info(
f"Processing partition: {partition} CUDA: {torch.cuda.is_available()}"
)
try:
cut_set = CutSet.from_manifests(
recordings=m["recordings"],
supervisions=m["supervisions"],
)
except Exception:
cut_set = m["cuts"]
# Split cut_set if split > 1
split = 1
if args.split > 1:
cut_sets = cut_set.split(args.split)
split = args.split
else:
cut_sets = [cut_set]
for idx, part in enumerate(cut_sets):
# AudioTokenizer
if args.audio_extractor:
if args.audio_extractor == "Encodec":
storage_path = (
f"{args.output_dir}/{args.prefix}_encodec_{partition}_{idx if split > 1 else ''}"
)
else:
storage_path = (
f"{args.output_dir}/{args.prefix}_fbank_{partition}_{idx if split > 1 else ''}"
)
if args.prefix.lower() in ["ljspeech", "aishell", "baker", "wenetspeech4tts"]:
part = part.resample(24000)
with torch.no_grad():
if (
torch.cuda.is_available()
and args.audio_extractor == "Encodec"
):
part = part.compute_and_store_features_batch(
extractor=audio_extractor,
storage_path=storage_path,
num_workers=num_jobs,
batch_duration=args.batch_duration,
collate=False,
overwrite=True,
storage_type=NumpyHdf5Writer,
)
else:
part = part.compute_and_store_features(
extractor=audio_extractor,
storage_path=storage_path,
num_jobs=num_jobs if ex is None else 64,
executor=ex,
storage_type=NumpyHdf5Writer,
)
# TextTokenizer
if args.text_extractor:
for c in tqdm(part):
if args.prefix == "baker" and args.text_extractor == "labeled_pinyin":
phonemes = c.supervisions[0].custom["tokens"]["text"]
unique_symbols.update(phonemes)
else:
if args.prefix == "ljspeech":
text = c.supervisions[0].custom["normalized_text"]
text = text.replace(""", '"').replace(""", '"')
phonemes = tokenize_text(text_tokenizer, text=text)
elif args.prefix in ["aishell", "aishell2", "wenetspeech4tts", "libritts"]:
phonemes = tokenize_text(
text_tokenizer, text=c.supervisions[0].text
)
if c.supervisions[0].custom is None:
c.supervisions[0].custom = {}
else:
raise NotImplementedError(f"{args.prefix}")
c.supervisions[0].custom["tokens"] = {"text": phonemes}
unique_symbols.update(phonemes)
# Save each part with an index if split > 1
cuts_filename = f"{prefix}cuts_{partition}.{idx if split > 1 else ''}.{args.suffix}"
part.to_file(f"{args.output_dir}/{cuts_filename}")
logging.info(f"Saved {cuts_filename}")
if args.text_extractor:
unique_phonemes = SymbolTable()
for s in sorted(list(unique_symbols)):
unique_phonemes.add(s)
logging.info(f"{len(unique_symbols)} unique phonemes: {unique_symbols}")
unique_phonemes_file = f"{args.output_dir}/unique_text_tokens.k2symbols"
unique_phonemes.to_file(unique_phonemes_file)
if __name__ == "__main__":
formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@ -0,0 +1 @@
../../../wenetspeech4tts/TTS/local/compute_neural_codec_and_prepare_text_tokens.py

View File

@ -32,7 +32,7 @@ if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then
cd vits/monotonic_align
python setup.py build_ext --inplace
cd ../../
else
else
log "monotonic_align lib already built"
fi
fi
@ -75,11 +75,11 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
log "Stage 2: Compute Spectrogram for LibriTTS"
mkdir -p data/spectrogram
if [ ! -e data/spectrogram/.libritts.done ]; then
./local/compute_spectrogram_libritts.py --sampling-rate $sampling_rate
./local/compute_spectrogram_libritts.py --sampling-rate $sampling_rate
touch data/spectrogram/.libritts.done
fi
# Here we shuffle and combine the train-clean-100, train-clean-360 and
# Here we shuffle and combine the train-clean-100, train-clean-360 and
# train-other-500 together to form the training set.
if [ ! -f data/spectrogram/libritts_cuts_train-all-shuf.jsonl.gz ]; then
cat <(gunzip -c data/spectrogram/libritts_cuts_train-clean-100.jsonl.gz) \
@ -88,7 +88,7 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
shuf | gzip -c > data/spectrogram/libritts_cuts_train-all-shuf.jsonl.gz
fi
# Here we shuffle and combine the train-clean-100, train-clean-360
# Here we shuffle and combine the train-clean-100, train-clean-360
# together to form the training set.
if [ ! -f data/spectrogram/libritts_cuts_train-clean-460.jsonl.gz ]; then
cat <(gunzip -c data/spectrogram/libritts_cuts_train-clean-100.jsonl.gz) \
@ -108,10 +108,10 @@ if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
log "Stage 3: Prepare phoneme tokens for LibriTTS"
# We assume you have installed piper_phonemize and espnet_tts_frontend.
# If not, please install them with:
# - piper_phonemize:
# - piper_phonemize:
# refer to https://github.com/rhasspy/piper-phonemize,
# could install the pre-built wheels from https://github.com/csukuangfj/piper-phonemize/releases/tag/2023.12.5
# - espnet_tts_frontend:
# - espnet_tts_frontend:
# `pip install espnet_tts_frontend`, refer to https://github.com/espnet/espnet_tts_frontend/
if [ ! -e data/spectrogram/.libritts_with_token.done ]; then
./local/prepare_tokens_libritts.py
@ -123,12 +123,39 @@ if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
log "Stage 4: Generate token file"
# We assume you have installed piper_phonemize and espnet_tts_frontend.
# If not, please install them with:
# - piper_phonemize:
# - piper_phonemize:
# refer to https://github.com/rhasspy/piper-phonemize,
# could install the pre-built wheels from https://github.com/csukuangfj/piper-phonemize/releases/tag/2023.12.5
# - espnet_tts_frontend:
# - espnet_tts_frontend:
# `pip install espnet_tts_frontend`, refer to https://github.com/espnet/espnet_tts_frontend/
if [ ! -e data/tokens.txt ]; then
./local/prepare_token_file.py --tokens data/tokens.txt
fi
fi
audio_feats_dir=data/tokenized
dataset_parts="--dataset-parts all" # debug "-p dev-clean -p test-clean"
if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
log "Stage 5: Tokenize/Fbank LibriTTS for valle"
mkdir -p ${audio_feats_dir}
if [ ! -e ${audio_feats_dir}/.libritts.tokenize.done ]; then
python3 ./local/compute_neural_codec_and_prepare_text_tokens.py --dataset-parts "${dataset_parts}" \
--audio-extractor "Encodec" \
--batch-duration 400 \
--src-dir "data/manifests" \
--output-dir "${audio_feats_dir}"
fi
touch ${audio_feats_dir}/.libritts.tokenize.done
lhotse combine \
${audio_feats_dir}/libritts_cuts_train-clean-100.jsonl.gz \
${audio_feats_dir}/libritts_cuts_train-clean-360.jsonl.gz \
${audio_feats_dir}/libritts_cuts_train-other-500.jsonl.gz \
${audio_feats_dir}/cuts_train.jsonl.gz
lhotse copy \
${audio_feats_dir}/libritts_cuts_dev-clean.jsonl.gz \
${audio_feats_dir}/cuts_dev.jsonl.gz
lhotse copy \
${audio_feats_dir}/libritts_cuts_test-clean.jsonl.gz \
${audio_feats_dir}/cuts_test.jsonl.gz
fi

1
egs/libritts/TTS/valle Symbolic link
View File

@ -0,0 +1 @@
../../wenetspeech4tts/TTS/valle/

View File

@ -0,0 +1,51 @@
# Introduction
LibriTTS is a multi-speaker English corpus of approximately 585 hours of read English speech at 24kHz sampling rate, prepared by Heiga Zen with the assistance of Google Speech and Google Brain team members.
The LibriTTS corpus is designed for TTS research. It is derived from the original materials (mp3 audio files from LibriVox and text files from Project Gutenberg) of the LibriSpeech corpus.
The main differences from the LibriSpeech corpus are listed below:
1. The audio files are at 24kHz sampling rate.
2. The speech is split at sentence breaks.
3. Both original and normalized texts are included.
4. Contextual information (e.g., neighbouring sentences) can be extracted.
5. Utterances with significant background noise are excluded.
For more information, refer to the paper "LibriTTS: A Corpus Derived from LibriSpeech for Text-to-Speech", Heiga Zen, Viet Dang, Rob Clark, Yu Zhang, Ron J. Weiss, Ye Jia, Zhifeng Chen, and Yonghui Wu, arXiv, 2019. If you use the LibriTTS corpus in your work, please cite this paper where it was introduced.
> [!CAUTION]
> The next-gen Kaldi framework provides tools and models for generating high-quality, synthetic speech (Text-to-Speech, TTS).
> While these recipes has the potential to advance various fields such as accessibility, language education, and AI-driven solutions, it also carries certain ethical and legal responsibilities.
>
> By using this framework, you agree to the following:
> 1. Legal and Ethical Use: You shall not use this framework, or any models derived from it, for any unlawful or unethical purposes. This includes, but is not limited to: Creating voice clones without the explicit, informed consent of the individual whose voice is being cloned. Engaging in any form of identity theft, impersonation, or fraud using cloned voices. Violating any local, national, or international laws regarding privacy, intellectual property, or personal data.
>
> 2. Responsibility of Use: The users of this framework are solely responsible for ensuring that their use of voice cloning technologies complies with all applicable laws and ethical guidelines. We explicitly disclaim any liability for misuse of the technology.
>
> 3. Attribution and Use of Open-Source Components: This project is provided under the Apache 2.0 license. Users must adhere to the terms of this license and provide appropriate attribution when required.
>
> 4. No Warranty: This framework is provided “as-is,” without warranty of any kind, either express or implied. We do not guarantee that the use of this software will comply with legal requirements or that it will not infringe the rights of third parties.
# VITS
This recipe provides a VITS model trained on the LibriTTS dataset.
Pretrained model can be found [here](https://huggingface.co/zrjin/icefall-tts-libritts-vits-2024-10-30).
The training command is given below:
```
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
./vits/train.py \
--world-size 4 \
--num-epochs 400 \
--start-epoch 1 \
--use-fp16 1 \
--exp-dir vits/exp \
--max-duration 500
```
To inference, use:
```
./vits/infer.py \
--exp-dir vits/exp \
--epoch 400 \
--tokens data/tokens.txt
```

View File

@ -0,0 +1,615 @@
#!/usr/bin/env python3
# Copyright 2023 (authors: Feiteng Li)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Phonemize Text and EnCodec Audio.
Usage example:
python3 bin/tokenizer.py \
--src_dir ./data/manifests --output_dir ./data/tokenized
"""
import argparse
import logging
import os
from dataclasses import asdict, dataclass
from pathlib import Path
from typing import Any, Dict, List, Optional, Union
import torch
import torch.multiprocessing
from encodec import EncodecModel
from encodec.utils import convert_audio
from lhotse import CutSet, NumpyHdf5Writer
from lhotse.features import FeatureExtractor
from lhotse.recipes.utils import read_manifests_if_cached
from lhotse.utils import Seconds, compute_num_frames
from phonemizer.backend import EspeakBackend
from phonemizer.backend.espeak.language_switch import LanguageSwitch
from phonemizer.backend.espeak.words_mismatch import WordMismatch
from phonemizer.punctuation import Punctuation
from phonemizer.separator import Separator
from tqdm.auto import tqdm
from icefall.utils import get_executor
try:
from pypinyin import Style, pinyin
from pypinyin.style._utils import get_finals, get_initials
except Exception:
pass
import re
from typing import Pattern
import numpy as np
from k2 import SymbolTable
# from valle.data import (
# AudioTokenConfig,
# AudioTokenExtractor,
# TextTokenizer,
# tokenize_text,
# )
# from valle.data.fbank import get_fbank_extractor
# from valle.utils import SymbolTable
os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"
# Torch's multithreaded behavior needs to be disabled or
# it wastes a lot of CPU and slow things down.
# Do this outside of main() in case it needs to take effect
# even when we are not invoking the main (e.g. when spawning subprocesses).
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
torch.multiprocessing.set_sharing_strategy("file_system")
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--src-dir",
type=Path,
default=Path("data/manifests"),
help="Path to the manifest files",
)
parser.add_argument(
"--output-dir",
type=Path,
default=Path("data/tokenized"),
help="Path to the tokenized files",
)
parser.add_argument(
"--text-extractor",
type=str,
default="espeak",
help="espeak or pypinyin or pypinyin_initials_finals",
)
parser.add_argument(
"--audio-extractor",
type=str,
default="Encodec",
help="Encodec or Fbank",
)
parser.add_argument(
"--dataset-parts",
type=str,
default="dev-clean test-clean",
help="Space separated dataset parts",
)
parser.add_argument(
"--prefix",
type=str,
default="libritts",
help="prefix of the manifest file",
)
parser.add_argument(
"--suffix",
type=str,
default="jsonl.gz",
help="suffix of the manifest file",
)
parser.add_argument(
"--batch-duration",
type=float,
default=400.0,
help="The maximum number of audio seconds in a batch."
"Determines batch size dynamically.",
)
parser.add_argument(
"--split",
type=int,
default=1,
help="Split the cut_set into multiple parts",
)
return parser.parse_args()
class PypinyinBackend:
"""PypinyinBackend for Chinese. Most codes is referenced from espnet.
There are two types pinyin or initials_finals, one is
just like "ni1 hao3", the other is like "n i1 h ao3".
"""
def __init__(
self,
backend="initials_finals",
punctuation_marks: Union[str, Pattern] = Punctuation.default_marks(),
) -> None:
self.backend = backend
self.punctuation_marks = punctuation_marks
def phonemize(
self, text: List[str], separator: Separator, strip=True, njobs=1
) -> List[str]:
assert isinstance(text, List)
phonemized = []
for _text in text:
_text = re.sub(" +", " ", _text.strip())
_text = _text.replace(" ", separator.word)
phones = []
if self.backend == "pypinyin":
for n, py in enumerate(
pinyin(_text, style=Style.TONE3, neutral_tone_with_five=True)
):
if all([c in self.punctuation_marks for c in py[0]]):
if len(phones):
assert phones[-1] == separator.syllable
phones.pop(-1)
phones.extend(list(py[0]))
else:
phones.extend([py[0], separator.syllable])
elif self.backend == "pypinyin_initials_finals":
for n, py in enumerate(
pinyin(_text, style=Style.TONE3, neutral_tone_with_five=True)
):
if all([c in self.punctuation_marks for c in py[0]]):
if len(phones):
assert phones[-1] == separator.syllable
phones.pop(-1)
phones.extend(list(py[0]))
else:
if py[0][-1].isalnum():
initial = get_initials(py[0], strict=False)
if py[0][-1].isdigit():
final = get_finals(py[0][:-1], strict=False) + py[0][-1]
else:
final = get_finals(py[0], strict=False)
phones.extend(
[
initial,
separator.phone,
final,
separator.syllable,
]
)
else:
assert ValueError
else:
raise NotImplementedError
phonemized.append(
"".join(phones).rstrip(f"{separator.word}{separator.syllable}")
)
return phonemized
class TextTokenizer:
"""Phonemize Text."""
def __init__(
self,
language="en-us",
backend="espeak",
separator=Separator(word="_", syllable="-", phone="|"),
preserve_punctuation=True,
punctuation_marks: Union[str, Pattern] = Punctuation.default_marks(),
with_stress: bool = False,
tie: Union[bool, str] = False,
language_switch: LanguageSwitch = "keep-flags",
words_mismatch: WordMismatch = "ignore",
) -> None:
if backend == "espeak":
phonemizer = EspeakBackend(
language,
punctuation_marks=punctuation_marks,
preserve_punctuation=preserve_punctuation,
with_stress=with_stress,
tie=tie,
language_switch=language_switch,
words_mismatch=words_mismatch,
)
elif backend in ["pypinyin", "pypinyin_initials_finals"]:
phonemizer = PypinyinBackend(
backend=backend,
punctuation_marks=punctuation_marks + separator.word,
)
else:
raise NotImplementedError(f"{backend}")
self.backend = phonemizer
self.separator = separator
def to_list(self, phonemized: str) -> List[str]:
fields = []
for word in phonemized.split(self.separator.word):
# "ɐ m|iː|n?" ɹ|ɪ|z|ɜː|v; h|ɪ|z.
pp = re.findall(r"\w+|[^\w\s]", word, re.UNICODE)
fields.extend(
[p for p in pp if p != self.separator.phone] + [self.separator.word]
)
assert len("".join(fields[:-1])) == len(phonemized) - phonemized.count(
self.separator.phone
)
return fields[:-1]
def __call__(self, text, strip=True) -> List[List[str]]:
if isinstance(text, str):
text = [text]
phonemized = self.backend.phonemize(
text, separator=self.separator, strip=strip, njobs=1
)
return [self.to_list(p) for p in phonemized]
def tokenize_text(tokenizer: TextTokenizer, text: str) -> List[str]:
phonemes = tokenizer([text.strip()])
return phonemes[0] # k2symbols
def remove_encodec_weight_norm(model):
from encodec.modules import SConv1d
from encodec.modules.seanet import SConvTranspose1d, SEANetResnetBlock
from torch.nn.utils import remove_weight_norm
encoder = model.encoder.model
for key in encoder._modules:
if isinstance(encoder._modules[key], SEANetResnetBlock):
remove_weight_norm(encoder._modules[key].shortcut.conv.conv)
block_modules = encoder._modules[key].block._modules
for skey in block_modules:
if isinstance(block_modules[skey], SConv1d):
remove_weight_norm(block_modules[skey].conv.conv)
elif isinstance(encoder._modules[key], SConv1d):
remove_weight_norm(encoder._modules[key].conv.conv)
decoder = model.decoder.model
for key in decoder._modules:
if isinstance(decoder._modules[key], SEANetResnetBlock):
remove_weight_norm(decoder._modules[key].shortcut.conv.conv)
block_modules = decoder._modules[key].block._modules
for skey in block_modules:
if isinstance(block_modules[skey], SConv1d):
remove_weight_norm(block_modules[skey].conv.conv)
elif isinstance(decoder._modules[key], SConvTranspose1d):
remove_weight_norm(decoder._modules[key].convtr.convtr)
elif isinstance(decoder._modules[key], SConv1d):
remove_weight_norm(decoder._modules[key].conv.conv)
class AudioTokenizer:
"""EnCodec audio."""
def __init__(
self,
device: Any = None,
) -> None:
# Instantiate a pretrained EnCodec model
model = EncodecModel.encodec_model_24khz()
model.set_target_bandwidth(6.0)
remove_encodec_weight_norm(model)
if not device:
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda:0")
self._device = device
self.codec = model.to(device)
self.sample_rate = model.sample_rate
self.channels = model.channels
@property
def device(self):
return self._device
def encode(self, wav: torch.Tensor) -> torch.Tensor:
return self.codec.encode(wav.to(self.device))
def decode(self, frames: torch.Tensor) -> torch.Tensor:
return self.codec.decode(frames)
@dataclass
class AudioTokenConfig:
frame_shift: Seconds = 320.0 / 24000
num_quantizers: int = 8
def to_dict(self) -> Dict[str, Any]:
return asdict(self)
@staticmethod
def from_dict(data: Dict[str, Any]) -> "AudioTokenConfig":
return AudioTokenConfig(**data)
class AudioTokenExtractor(FeatureExtractor):
name = "encodec"
config_type = AudioTokenConfig
def __init__(self, config: Optional[Any] = None):
super(AudioTokenExtractor, self).__init__(config)
self.tokenizer = AudioTokenizer()
def extract(
self, samples: Union[np.ndarray, torch.Tensor], sampling_rate: int
) -> np.ndarray:
if not isinstance(samples, torch.Tensor):
samples = torch.from_numpy(samples)
if sampling_rate != self.tokenizer.sample_rate:
samples = convert_audio(
samples,
sampling_rate,
self.tokenizer.sample_rate,
self.tokenizer.channels,
)
if len(samples.shape) == 2:
samples = samples.unsqueeze(0)
else:
raise ValueError()
device = self.tokenizer.device
encoded_frames = self.tokenizer.encode(samples.detach().to(device))
codes = encoded_frames[0][0] # [B, n_q, T]
if True:
duration = round(samples.shape[-1] / sampling_rate, ndigits=12)
expected_num_frames = compute_num_frames(
duration=duration,
frame_shift=self.frame_shift,
sampling_rate=sampling_rate,
)
assert abs(codes.shape[-1] - expected_num_frames) <= 1
codes = codes[..., :expected_num_frames]
return codes.cpu().squeeze(0).permute(1, 0).numpy()
@property
def frame_shift(self) -> Seconds:
return self.config.frame_shift
def feature_dim(self, sampling_rate: int) -> int:
return self.config.num_quantizers
def pad_tensor_list(self, tensor_list, device, padding_value=0):
lengths = [tensor.shape[0] for tensor in tensor_list]
tensor_list = [torch.Tensor(t).to(device) for t in tensor_list]
padded_tensor = torch.nn.utils.rnn.pad_sequence(
tensor_list, batch_first=True, padding_value=padding_value
)
return padded_tensor, lengths
def extract_batch(self, samples, sampling_rate, lengths) -> np.ndarray:
samples = [wav.squeeze() for wav in samples]
device = self.tokenizer.device
samples, lengths = self.pad_tensor_list(samples, device)
samples = samples.unsqueeze(1)
if not isinstance(samples, torch.Tensor):
samples = torch.from_numpy(samples)
if len(samples.shape) != 3:
raise ValueError()
if sampling_rate != self.tokenizer.sample_rate:
samples = [
convert_audio(
wav,
sampling_rate,
self.tokenizer.sample_rate,
self.tokenizer.channels,
)
for wav in samples
]
samples = torch.stack(samples, 0) # convert samples from list to tensor
# Extract discrete codes from EnCodec
with torch.no_grad():
encoded_frames = self.tokenizer.encode(samples.detach().to(device))
encoded_frames = encoded_frames[0][0] # [B, n_q, T]
batch_codes = []
for b, length in enumerate(lengths):
codes = encoded_frames[b]
duration = round(length / sampling_rate, ndigits=12)
expected_num_frames = compute_num_frames(
duration=duration,
frame_shift=self.frame_shift,
sampling_rate=sampling_rate,
)
batch_codes.append(codes[..., :expected_num_frames])
return [codes.cpu().permute(1, 0).numpy() for codes in batch_codes]
def main():
args = get_args()
dataset_parts = args.dataset_parts.replace("--dataset-parts", "").strip()
if dataset_parts == "all": # LibriTTS
dataset_parts = [
"dev-clean",
"dev-other",
"test-clean",
"test-other",
"train-clean-100",
"train-clean-360",
"train-other-500",
]
else:
dataset_parts = dataset_parts.replace("-p", "").strip().split(" ")
assert len(dataset_parts) >= 1
manifests = read_manifests_if_cached(
dataset_parts=dataset_parts,
output_dir=args.src_dir,
prefix=args.prefix,
suffix=args.suffix,
types=["recordings", "supervisions", "cuts"],
)
text_tokenizer = None
if args.text_extractor:
text_tokenizer = TextTokenizer(backend=args.text_extractor)
audio_extractor = None
if args.audio_extractor:
if args.audio_extractor == "Encodec":
audio_extractor = AudioTokenExtractor(AudioTokenConfig())
else:
raise NotImplementedError(f"{args.audio_extractor}")
Path(args.output_dir).mkdir(parents=True, exist_ok=True)
unique_symbols = set()
num_jobs = min(32, os.cpu_count())
logging.info(f"dataset_parts: {dataset_parts} manifests {len(manifests)}")
prefix = args.prefix
if prefix and not prefix.endswith("_"):
prefix = f"{prefix}_"
with get_executor() as ex:
for partition, m in manifests.items():
logging.info(
f"Processing partition: {partition} CUDA: {torch.cuda.is_available()}"
)
try:
cut_set = CutSet.from_manifests(
recordings=m["recordings"],
supervisions=m["supervisions"],
)
except Exception:
cut_set = m["cuts"]
# Split cut_set if split > 1
split = 1
if args.split > 1:
cut_sets = cut_set.split(args.split)
split = args.split
else:
cut_sets = [cut_set]
for idx, part in enumerate(cut_sets):
if args.audio_extractor:
if args.audio_extractor == "Encodec":
storage_path = f"{args.output_dir}/{args.prefix}_encodec_{partition}_{idx if split > 1 else ''}"
else:
storage_path = f"{args.output_dir}/{args.prefix}_fbank_{partition}_{idx if split > 1 else ''}"
if args.prefix.lower() in [
"ljspeech",
"aishell",
"baker",
"wenetspeech4tts",
]:
part = part.resample(24000)
assert args.prefix_lower() in [
"ljspeech",
"aishell",
"baker",
"wenetspeech4tts",
"libritts",
"libritts-r",
]
with torch.no_grad():
if (
torch.cuda.is_available()
and args.audio_extractor == "Encodec"
):
part = part.compute_and_store_features_batch(
extractor=audio_extractor,
storage_path=storage_path,
num_workers=num_jobs,
batch_duration=args.batch_duration,
collate=False,
overwrite=True,
storage_type=NumpyHdf5Writer,
)
else:
part = part.compute_and_store_features(
extractor=audio_extractor,
storage_path=storage_path,
num_jobs=num_jobs if ex is None else 64,
executor=ex,
storage_type=NumpyHdf5Writer,
)
# TextTokenizer
if args.text_extractor:
for c in tqdm(part):
if (
args.prefix == "baker"
and args.text_extractor == "labeled_pinyin"
):
phonemes = c.supervisions[0].custom["tokens"]["text"]
unique_symbols.update(phonemes)
else:
if args.prefix == "ljspeech":
text = c.supervisions[0].custom["normalized_text"]
text = text.replace(""", '"').replace(""", '"')
phonemes = tokenize_text(text_tokenizer, text=text)
elif args.prefix in [
"aishell",
"aishell2",
"wenetspeech4tts",
"libritts",
"libritts-r",
]:
phonemes = tokenize_text(
text_tokenizer, text=c.supervisions[0].text
)
if c.supervisions[0].custom is None:
c.supervisions[0].custom = {}
c.supervisions[0].normalized_text = c.supervisions[
0
].text
else:
raise NotImplementedError(f"{args.prefix}")
c.supervisions[0].custom["tokens"] = {"text": phonemes}
unique_symbols.update(phonemes)
c.tokens = phonemes
assert c.supervisions[
0
].normalized_text, "normalized_text is None"
# Save each part with an index if split > 1
cuts_filename = (
f"{prefix}cuts_{partition}.{idx if split > 1 else ''}.{args.suffix}"
)
part.to_file(f"{args.output_dir}/{cuts_filename}")
logging.info(f"Saved {cuts_filename}")
if args.text_extractor:
unique_phonemes = SymbolTable()
for s in sorted(list(unique_symbols)):
unique_phonemes.add(s)
logging.info(f"{len(unique_symbols)} unique phonemes: {unique_symbols}")
unique_phonemes_file = f"{args.output_dir}/unique_text_tokens.k2symbols"
unique_phonemes.to_file(unique_phonemes_file)
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@ -0,0 +1,53 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
# Copyright 2023 (authors: Feiteng Li)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This file displays duration statistics of utterances in the manifests.
You can use the displayed value to choose minimum/maximum duration
to remove short and long utterances during the training.
"""
import argparse
from pathlib import Path
from lhotse import load_manifest_lazy
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--manifest-dir",
type=Path,
default=Path("data/tokenized"),
help="Path to the tokenized manifests.",
)
return parser.parse_args()
def main():
args = get_args()
manifest_dir = args.manifest_dir or Path("data/tokenized")
for part in ["train", "dev", "test"]:
print(f"## {part}")
cuts = load_manifest_lazy(manifest_dir / f"cuts_{part}.jsonl.gz")
cuts.describe()
print("\n")
if __name__ == "__main__":
main()

View File

@ -0,0 +1,101 @@
#!/usr/bin/env bash
set -eou pipefail
# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674
export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
j=16
stage=2
stop_stage=2
dl_dir=$PWD/download
dataset_parts="-p Basic" # -p Premium for Premium dataset only
text_extractor="pypinyin_initials_finals" # default is espeak for English
audio_extractor="Encodec" # or Fbank
audio_feats_dir=data/tokenized
. shared/parse_options.sh || exit 1
# All files generated by this script are saved in "data".
# You can safely remove "data" and rerun this script to regenerate it.
mkdir -p data
log() {
# This function is from espnet
local fname=${BASH_SOURCE[1]##*/}
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}
if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
log "dl_dir: $dl_dir"
log "Stage 0: Download data"
huggingface-cli login
huggingface-cli download --repo-type dataset --local-dir $dl_dir Wenetspeech4TTS/WenetSpeech4TTS
# Extract the downloaded data:
for folder in Standard Premium Basic; do
for file in "$dl_dir/$folder"/*.tar.gz; do
tar -xzvf "$file" -C "$dl_dir/$folder"
done
done
fi
if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
log "Stage 1: Prepare wenetspeech4tts manifest"
# We assume that you have downloaded the wenetspeech4tts corpus
# to $dl_dir/wenetspeech4tts
mkdir -p data/manifests
if [ ! -e data/manifests/.wenetspeech4tts.done ]; then
lhotse prepare wenetspeech4tts $dl_dir data/manifests --dataset-parts "${dataset_parts}"
touch data/manifests/.wenetspeech4tts.done
fi
fi
if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
log "Stage 2: Tokenize/Fbank wenetspeech4tts"
mkdir -p ${audio_feats_dir}
if [ ! -e ${audio_feats_dir}/.wenetspeech4tts.tokenize.done ]; then
python3 ./local/compute_neural_codec_and_prepare_text_tokens.py --dataset-parts "${dataset_parts}" \
--text-extractor ${text_extractor} \
--audio-extractor ${audio_extractor} \
--batch-duration 2500 \
--prefix "wenetspeech4tts" \
--src-dir "data/manifests" \
--split 100 \
--output-dir "${audio_feats_dir}/${prefix}_baisc_split_100"
fi
touch ${audio_feats_dir}/.wenetspeech4tts.tokenize.done
fi
if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
log "Stage 13: Combine features for basic"
if [ ! -f ${audio_feats_dir}/wenetspeech4tts_cuts_Baisc.jsonl.gz ]; then
pieces=$(find ${audio_feats_dir}/wenetspeech4tts_baisc_split_100 -name "*.jsonl.gz")
lhotse combine $pieces ${audio_feats_dir}/wenetspeech4tts_cuts_Baisc.jsonl.gz
fi
fi
if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
log "Stage 3: Prepare wenetspeech4tts train/dev/test"
if [ ! -e ${audio_feats_dir}/.wenetspeech4tts.train.done ]; then
lhotse subset --first 400 \
${audio_feats_dir}/wenetspeech4tts_cuts_Baisc.jsonl.gz \
${audio_feats_dir}/cuts_dev.jsonl.gz
lhotse subset --last 400 \
${audio_feats_dir}/wenetspeech4tts_cuts_Baisc.jsonl.gz \
${audio_feats_dir}/cuts_test.jsonl.gz
lhotse copy \
${audio_feats_dir}/wenetspeech4tts_cuts_Baisc.jsonl.gz \
${audio_feats_dir}/cuts_train.jsonl.gz
touch ${audio_feats_dir}/.wenetspeech4tts.train.done
fi
python3 ./local/display_manifest_statistics.py --manifest-dir ${audio_feats_dir}
fi

View File

@ -0,0 +1 @@
../../../icefall/shared/

View File

@ -0,0 +1 @@
../local/compute_neural_codec_and_prepare_text_tokens.py

View File

@ -40,16 +40,17 @@ os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"
import torch
import torchaudio
from icefall.utils import AttributeDict, str2bool
from valle.data import (
from compute_neural_codec_and_prepare_text_tokens import (
AudioTokenizer,
TextTokenizer,
tokenize_audio,
tokenize_text,
)
from valle.data.collation import get_text_token_collater
from valle.models import get_model
from k2 import symbol_table
from tokenizer import get_text_token_collater
from valle import VALLE
from icefall.utils import AttributeDict, str2bool
def get_args():
@ -70,21 +71,12 @@ def get_args():
)
parser.add_argument(
"--text",
"--manifest",
type=str,
default="To get up and running quickly just follow the steps below.",
help="Text to be synthesized.",
default="",
help="prompt text\t prompt audio\ttarget text\ttarget audio",
)
# model
# add_model_arguments(parser)
# parser.add_argument(
# "--text-tokens",
# type=str,
# default="data/tokenized/unique_text_tokens.k2symbols",
# help="Path to the unique text tokens file.",
# )
parser.add_argument(
"--text-extractor",
type=str,
@ -143,8 +135,19 @@ def load_model(checkpoint, device):
checkpoint = torch.load(checkpoint, map_location=device)
args = AttributeDict(checkpoint)
model = get_model(args)
params = AttributeDict(checkpoint)
model = VALLE(
params.decoder_dim,
params.nhead,
params.num_decoder_layers,
norm_first=params.norm_first,
add_prenet=params.add_prenet,
prefix_mode=params.prefix_mode,
share_embedding=params.share_embedding,
nar_scale_factor=params.scale_factor,
prepend_bos=params.prepend_bos,
num_quantizers=params.num_quantizers,
)
missing_keys, unexpected_keys = model.load_state_dict(
checkpoint["model"], strict=True
@ -153,9 +156,7 @@ def load_model(checkpoint, device):
model.to(device)
model.eval()
text_tokens = args.text_tokens
return model, text_tokens
return model, params.text_tokens
@torch.no_grad()
@ -181,9 +182,7 @@ def main():
encoded_frames = tokenize_audio(audio_tokenizer, audio_file)
if False:
samples = audio_tokenizer.decode(encoded_frames)
torchaudio.save(
f"{args.output_dir}/p{n}.wav", samples[0], 24000
)
torchaudio.save(f"{args.output_dir}/p{n}.wav", samples[0], 24000)
audio_prompts.append(encoded_frames[0][0])
@ -195,8 +194,8 @@ def main():
# https://github.com/lifeiteng/lifeiteng.github.com/blob/main/valle/prepare.py
with open(args.text) as f:
for line in f:
# fields = line.strip().split("\t")
fields = line.strip().split(" ")
fields = line.strip().split("\t")
# fields = line.strip().split(" ")
fields = [item for item in fields if item]
assert len(fields) == 4
prompt_text, prompt_audio, text, audio_path = fields
@ -209,11 +208,7 @@ def main():
]
)
_, enroll_x_lens = text_collater(
[
tokenize_text(
text_tokenizer, text=f"{prompt_text}".strip()
)
]
[tokenize_text(text_tokenizer, text=f"{prompt_text}".strip())]
)
audio_prompts = tokenize_audio(audio_tokenizer, prompt_audio)
@ -244,11 +239,7 @@ def main():
for n, text in enumerate(args.text.split("|")):
logging.info(f"synthesize text: {text}")
text_tokens, text_tokens_lens = text_collater(
[
tokenize_text(
text_tokenizer, text=f"{text_prompts} {text}".strip()
)
]
[tokenize_text(text_tokenizer, text=f"{text_prompts} {text}".strip())]
)
# synthesis
@ -263,11 +254,7 @@ def main():
enroll_x_lens = None
if text_prompts:
_, enroll_x_lens = text_collater(
[
tokenize_text(
text_tokenizer, text=f"{text_prompts}".strip()
)
]
[tokenize_text(text_tokenizer, text=f"{text_prompts}".strip())]
)
encoded_frames = model.inference(
text_tokens.to(device),
@ -280,13 +267,9 @@ def main():
)
if audio_prompts != []:
samples = audio_tokenizer.decode(
[(encoded_frames.transpose(2, 1), None)]
)
samples = audio_tokenizer.decode([(encoded_frames.transpose(2, 1), None)])
# store
torchaudio.save(
f"{args.output_dir}/{n}.wav", samples[0].cpu(), 24000
)
torchaudio.save(f"{args.output_dir}/{n}.wav", samples[0].cpu(), 24000)
else: # Transformer
pass
@ -297,8 +280,6 @@ torch._C._jit_set_profiling_executor(False)
torch._C._jit_set_profiling_mode(False)
torch._C._set_graph_executor_optimize(False)
if __name__ == "__main__":
formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@ -3,9 +3,9 @@ from typing import List, Tuple
import numpy as np
import torch
from k2 import SymbolTable
class TextTokenCollater:
"""Collate list of text tokens
@ -52,15 +52,10 @@ class TextTokenCollater:
self.token2idx = {token: idx for idx, token in enumerate(unique_tokens)}
self.idx2token = [token for token in unique_tokens]
def index(
self, tokens_list: List[str]
) -> Tuple[torch.Tensor, torch.Tensor]:
def index(self, tokens_list: List[str]) -> Tuple[torch.Tensor, torch.Tensor]:
seqs, seq_lens = [], []
for tokens in tokens_list:
assert (
all([True if s in self.token2idx else False for s in tokens])
is True
)
assert all([True if s in self.token2idx else False for s in tokens]) is True
seq = (
([self.bos_symbol] if self.add_bos else [])
+ list(tokens)
@ -103,10 +98,7 @@ class TextTokenCollater:
)
tokens_lens = torch.IntTensor(
[
len(seq) + int(self.add_eos) + int(self.add_bos)
for seq in tokens_seqs
]
[len(seq) + int(self.add_eos) + int(self.add_bos) for seq in tokens_seqs]
)
return tokens_batch, tokens_lens
@ -115,7 +107,5 @@ class TextTokenCollater:
def get_text_token_collater(text_tokens_file: str) -> TextTokenCollater:
text_tokens_path = Path(text_tokens_file)
unique_tokens = SymbolTable.from_file(text_tokens_path)
collater = TextTokenCollater(
unique_tokens.symbols, add_bos=True, add_eos=True
)
collater = TextTokenCollater(unique_tokens.symbols, add_bos=True, add_eos=True)
return collater

View File

@ -49,10 +49,9 @@ import argparse
import copy
import logging
import os
from contextlib import nullcontext
import random
import warnings
from contextlib import nullcontext
from pathlib import Path
from shutil import copyfile
from typing import Any, Dict, Optional, Tuple, Union
@ -60,6 +59,19 @@ from typing import Any, Dict, Optional, Tuple, Union
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from lhotse import CutSet
from lhotse.cut import Cut
from lhotse.dataset.sampling.base import CutSampler
from lhotse.utils import fix_random_seed
from optim import Eden, ScaledAdam
from tokenizer import TextTokenCollater, get_text_token_collater
from torch import Tensor
from torch.cuda.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter
from tts_datamodule import TtsDataModule
from valle import VALLE
from icefall.checkpoint import load_checkpoint, remove_checkpoints
from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
from icefall.checkpoint import (
@ -70,22 +82,10 @@ from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall.hooks import register_inf_check_hooks
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
from lhotse import CutSet
from lhotse.cut import Cut
from lhotse.dataset.sampling.base import CutSampler
from lhotse.utils import fix_random_seed
from torch import Tensor
from torch.cuda.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter
from tts_datamodule import TtsDataModule
from optim import Eden, ScaledAdam
from valle import VALLE
from tokenizer import TextTokenCollater, get_text_token_collater
LRSchedulerType = torch.optim.lr_scheduler._LRScheduler
def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None:
if isinstance(model, DDP):
# get underlying nn.Module
@ -95,6 +95,7 @@ def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None:
if hasattr(module, "batch_count"):
module.batch_count = batch_count
def add_model_arguments(parser: argparse.ArgumentParser):
parser.add_argument(
"--decoder-dim",
@ -159,6 +160,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
help="Number of Audio/Semantic quantization layers.",
)
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
@ -363,7 +365,6 @@ def get_parser():
return parser
def get_params() -> AttributeDict:
"""Return a dict containing training parameters.
@ -568,9 +569,10 @@ def save_checkpoint(
best_valid_filename = params.exp_dir / "best-valid-loss.pt"
copyfile(src=filename, dst=best_valid_filename)
def prepare_input(batch: dict, tokenizer: Tokenizer, device: torch.device):
def prepare_input(batch: dict, tokenizer: TextTokenCollater, device: torch.device):
"""Parse batch data"""
features = batch["features"].to(device)
features_lens = batch["features_lens"].to(device)
if "tokens" not in batch:
@ -590,6 +592,7 @@ def prepare_input(batch: dict, tokenizer: Tokenizer, device: torch.device):
return features, features_lens, text_tokens, text_tokens_lens
def compute_loss(
params: AttributeDict,
model: Union[nn.Module, DDP],
@ -615,11 +618,7 @@ def compute_loss(
warmup: a floating point value which increases throughout training;
values >= 1.0 are fully warmed up and have all modules present.
"""
device = (
model.device
if isinstance(model, DDP)
else next(model.parameters()).device
)
device = model.device if isinstance(model, DDP) else next(model.parameters()).device
(
audio_features,
audio_features_lens,
@ -684,9 +683,7 @@ def compute_validation_loss(
params.best_valid_loss = loss_value
if params.visualize:
output_dir = Path(
f"{params.exp_dir}/eval/step-{params.batch_idx_train:06d}"
)
output_dir = Path(f"{params.exp_dir}/eval/step-{params.batch_idx_train:06d}")
output_dir.mkdir(parents=True, exist_ok=True)
if isinstance(model, DDP):
model.module.visualize(predicts, batch, output_dir=output_dir)
@ -777,26 +774,21 @@ def train_one_epoch(
is_training=True,
)
# summary stats
tot_loss = (
tot_loss * (1 - 1 / params.reset_interval)
) + loss_info * (1 / params.reset_interval)
tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info * (
1 / params.reset_interval
)
# NOTE: We use reduction==sum and loss is computed over utterances
# in the batch and there is no normalization to it so far.
scaler.scale(loss).backward()
if params.batch_idx_train >= params.accumulate_grad_steps:
if (
params.batch_idx_train % params.accumulate_grad_steps
== 0
):
if params.batch_idx_train % params.accumulate_grad_steps == 0:
if params.optimizer_name not in ["ScaledAdam", "Eve"]:
# Unscales the gradients of optimizer's assigned params in-place
scaler.unscale_(optimizer)
# Since the gradients of optimizer's assigned params are unscaled, clips as usual:
torch.nn.utils.clip_grad_norm_(
model.parameters(), 1.0
)
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
scaler.step(optimizer)
scaler.update()
@ -825,7 +817,7 @@ def train_one_epoch(
model_cur=model,
model_avg=model_avg,
)
if (
params.batch_idx_train > 0
and params.batch_idx_train % params.save_every_n == 0
@ -849,15 +841,13 @@ def train_one_epoch(
topk=params.keep_last_k,
rank=rank,
)
if batch_idx % 100 == 0 and params.dtype in ["float16", "fp16"]:
# If the grad scale was less than 1, try increasing it. The _growth_interval
# of the grad scaler is configurable, but we can't configure it to have different
# behavior depending on the current grad scale.
cur_grad_scale = scaler._scale.item()
if cur_grad_scale < 1.0 or (
cur_grad_scale < 8.0 and batch_idx % 400 == 0
):
if cur_grad_scale < 1.0 or (cur_grad_scale < 8.0 and batch_idx % 400 == 0):
scaler.update(cur_grad_scale * 2.0)
if cur_grad_scale < 0.01:
@ -870,9 +860,7 @@ def train_one_epoch(
if batch_idx % params.log_interval == 0:
cur_lr = scheduler.get_last_lr()[0]
cur_grad_scale = (
scaler._scale.item()
if params.dtype in ["float16", "fp16"]
else 1.0
scaler._scale.item() if params.dtype in ["float16", "fp16"] else 1.0
)
logging.info(
@ -897,12 +885,8 @@ def train_one_epoch(
"train/current_",
params.batch_idx_train,
)
tot_loss.write_summary(
tb_writer, "train/tot_", params.batch_idx_train
)
tot_loss.write_summary(
tb_writer, "train/tot_", params.batch_idx_train
)
tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
if params.dtype in ["float16", "fp16"]:
tb_writer.add_scalar(
"train/grad_scale",
@ -922,9 +906,7 @@ def train_one_epoch(
valid_dl=valid_dl,
world_size=world_size,
)
logging.info(
f"Epoch {params.cur_epoch}, validation: {valid_info}"
)
logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
logging.info(
f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
)
@ -1063,10 +1045,7 @@ def run(rank, world_size, args):
)
else:
parameters_names.append(
[
name_param_pair[0]
for name_param_pair in model.named_parameters()
]
[name_param_pair[0] for name_param_pair in model.named_parameters()]
)
optimizer = ScaledAdam(
@ -1144,9 +1123,7 @@ def run(rank, world_size, args):
params=params,
)
scaler = GradScaler(
enabled=(params.dtype in ["fp16", "float16"]), init_scale=1.0
)
scaler = GradScaler(enabled=(params.dtype in ["fp16", "float16"]), init_scale=1.0)
if checkpoints and "grad_scaler" in checkpoints:
logging.info("Loading grad scaler state dict")
scaler.load_state_dict(checkpoints["grad_scaler"])

View File

@ -52,6 +52,7 @@ class _SeedWorkers:
def __call__(self, worker_id: int):
fix_random_seed(self.seed + worker_id)
class TtsDataModule:
"""
DataModule for tts experiments.
@ -301,9 +302,7 @@ class TtsDataModule:
@lru_cache()
def train_cuts(self) -> CutSet:
logging.info("About to get train cuts")
return load_manifest_lazy(
self.args.manifest_dir / "cuts_train.jsonl.gz"
)
return load_manifest_lazy(self.args.manifest_dir / "cuts_train.jsonl.gz")
@lru_cache()
def dev_cuts(self) -> CutSet:
@ -341,4 +340,4 @@ class TtsDataModule:
logging.info("About to get test-other cuts")
return load_manifest_lazy(
self.args.manifest_dir / "libritts_cuts_test-other.jsonl.gz"
)
)

View File

@ -12,36 +12,36 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import math
import numbers
import random
from typing import Dict, Iterator, List, Tuple, Union
from functools import partial
from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from icefall.utils import make_pad_mask
from torch import Tensor
from torch.nn import Linear, Module
from torch.nn import functional as F
from torch.nn.init import constant_, xavier_normal_, xavier_uniform_
from torch.nn.modules.linear import NonDynamicallyQuantizableLinear
from torch.nn.parameter import Parameter
from torchmetrics.classification import MulticlassAccuracy
# from valle.data.input_strategies import PromptedFeatures
# from valle.modules.embedding import SinePositionalEmbedding, TokenEmbedding
# from valle.modules.transformer import (
# AdaptiveLayerNorm,
# LayerNorm,
# TransformerEncoder,
# TransformerEncoderLayer,
# )
from icefall.utils import make_pad_mask
from .macros import NUM_AUDIO_TOKENS, NUM_TEXT_TOKENS
from .visualizer import visualize
class PromptedFeatures:
def __init__(self, prompts, features):
self.prompts = prompts
self.features = features
def to(self, device):
return PromptedFeatures(
self.prompts.to(device), self.features.to(device)
)
return PromptedFeatures(self.prompts.to(device), self.features.to(device))
def sum(self):
return self.features.sum()
@ -54,6 +54,7 @@ class PromptedFeatures:
def data(self):
return (self.prompts, self.features)
class TokenEmbedding(nn.Module):
def __init__(
self,
@ -114,9 +115,7 @@ class SinePositionalEmbedding(nn.Module):
x.size(1) - 1, -1, -1.0, dtype=torch.float32
).unsqueeze(1)
else:
position = torch.arange(
0, x.size(1), dtype=torch.float32
).unsqueeze(1)
position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
div_term = torch.exp(
torch.arange(0, self.dim_model, 2, dtype=torch.float32)
* -(math.log(10000.0) / self.dim_model)
@ -132,14 +131,17 @@ class SinePositionalEmbedding(nn.Module):
output = output * self.x_scale + self.alpha * self.pe[:, : x.size(1)]
return self.dropout(output)
class Transpose(nn.Identity):
"""(N, T, D) -> (N, D, T)"""
def forward(self, input: torch.Tensor) -> torch.Tensor:
return input.transpose(1, 2)
_shape_t = Union[int, List[int], torch.Size]
class MultiheadAttention(Module):
r"""Allows the model to jointly attend to information
from different representation subspaces as described in the paper:
@ -221,9 +223,7 @@ class MultiheadAttention(Module):
self.embed_dim = embed_dim
self.kdim = kdim if kdim is not None else embed_dim
self.vdim = vdim if vdim is not None else embed_dim
self._qkv_same_embed_dim = (
self.kdim == embed_dim and self.vdim == embed_dim
)
self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim
self.num_heads = num_heads
self.dropout = dropout
@ -234,12 +234,8 @@ class MultiheadAttention(Module):
), "embed_dim must be divisible by num_heads"
if add_bias_kv:
self.bias_k = Parameter(
torch.empty((1, 1, embed_dim), **factory_kwargs)
)
self.bias_v = Parameter(
torch.empty((1, 1, embed_dim), **factory_kwargs)
)
self.bias_k = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs))
self.bias_v = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs))
else:
self.bias_k = self.bias_v = None
@ -396,20 +392,18 @@ class MultiheadAttention(Module):
)
why_not_fast_path = ""
if not is_batched:
why_not_fast_path = f"input not batched; expected query.dim() of 3 but got {query.dim()}"
why_not_fast_path = (
f"input not batched; expected query.dim() of 3 but got {query.dim()}"
)
elif query is not key or key is not value:
# When lifting this restriction, don't forget to either
# enforce that the dtypes all match or test cases where
# they don't!
why_not_fast_path = "non-self attention was used (query, key, and value are not the same Tensor)"
elif (
self.in_proj_bias is not None
and query.dtype != self.in_proj_bias.dtype
):
elif self.in_proj_bias is not None and query.dtype != self.in_proj_bias.dtype:
why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_bias ({self.in_proj_bias.dtype}) don't match"
elif (
self.in_proj_weight is not None
and query.dtype != self.in_proj_weight.dtype
self.in_proj_weight is not None and query.dtype != self.in_proj_weight.dtype
):
# this case will fail anyway, but at least they'll get a useful error message.
why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_weight ({self.in_proj_weight.dtype}) don't match"
@ -458,9 +452,7 @@ class MultiheadAttention(Module):
for x in tensor_args
]
):
why_not_fast_path = (
"some Tensor argument is neither CUDA nor CPU"
)
why_not_fast_path = "some Tensor argument is neither CUDA nor CPU"
elif torch.is_grad_enabled() and any(
[x is not None and x.requires_grad for x in tensor_args]
):
@ -479,9 +471,7 @@ class MultiheadAttention(Module):
self.in_proj_bias,
self.out_proj.weight,
self.out_proj.bias,
key_padding_mask
if key_padding_mask is not None
else attn_mask,
key_padding_mask if key_padding_mask is not None else attn_mask,
need_weights,
average_attn_weights,
1
@ -506,9 +496,7 @@ class MultiheadAttention(Module):
query, key = [x.transpose(1, 0) for x in (query, key)]
value = key
else:
query, key, value = [
x.transpose(1, 0) for x in (query, key, value)
]
query, key, value = [x.transpose(1, 0) for x in (query, key, value)]
if not self._qkv_same_embed_dim:
attn_output, attn_output_weights = F.multi_head_attention_forward(
@ -722,14 +710,11 @@ class TransformerEncoderLayer(nn.Module):
self.activation = activation
norm1 = layer_norm_cls(d_model, eps=layer_norm_eps, **factory_kwargs)
if layer_norm_cls == IdentityNorm:
norm2 = BalancedBasicNorm(
d_model, eps=layer_norm_eps, **factory_kwargs
)
else:
norm2 = layer_norm_cls(
d_model, eps=layer_norm_eps, **factory_kwargs
)
# if layer_norm_cls == IdentityNorm:
# norm2 = BalancedBasicNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
# else:
if True:
norm2 = layer_norm_cls(d_model, eps=layer_norm_eps, **factory_kwargs)
if adaptive_layer_norm:
self.norm1 = AdaptiveLayerNorm(d_model, norm1)
@ -887,7 +872,6 @@ class TransformerEncoder(nn.Module):
return output
def _get_clones(module, N):
return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
@ -898,9 +882,8 @@ def _get_activation_fn(activation: str) -> Callable[[Tensor], Tensor]:
elif activation == "gelu":
return F.gelu
raise RuntimeError(
"activation should be relu/gelu, not {}".format(activation)
)
raise RuntimeError("activation should be relu/gelu, not {}".format(activation))
class VALLE(nn.Module):
"""It implements https://arxiv.org/abs/2301.02111
@ -1003,9 +986,7 @@ class VALLE(nn.Module):
num_layers=num_layers,
norm=LayerNorm(d_model) if norm_first else None,
)
self.ar_predict_layer = nn.Linear(
d_model, NUM_AUDIO_TOKENS + 1, bias=False
)
self.ar_predict_layer = nn.Linear(d_model, NUM_AUDIO_TOKENS + 1, bias=False)
self.ar_accuracy_metric = MulticlassAccuracy(
NUM_AUDIO_TOKENS + 1,
@ -1034,21 +1015,15 @@ class VALLE(nn.Module):
if add_prenet:
self.nar_text_prenet = nn.Sequential(
Transpose(),
nn.Conv1d(
nar_d_model, nar_d_model, kernel_size=5, padding="same"
),
nn.Conv1d(nar_d_model, nar_d_model, kernel_size=5, padding="same"),
nn.BatchNorm1d(nar_d_model),
nn.ReLU(),
nn.Dropout(0.5),
nn.Conv1d(
nar_d_model, nar_d_model, kernel_size=5, padding="same"
),
nn.Conv1d(nar_d_model, nar_d_model, kernel_size=5, padding="same"),
nn.BatchNorm1d(nar_d_model),
nn.ReLU(),
nn.Dropout(0.5),
nn.Conv1d(
nar_d_model, nar_d_model, kernel_size=5, padding="same"
),
nn.Conv1d(nar_d_model, nar_d_model, kernel_size=5, padding="same"),
nn.BatchNorm1d(nar_d_model),
nn.ReLU(),
nn.Dropout(0.5),
@ -1092,9 +1067,7 @@ class VALLE(nn.Module):
adaptive_layer_norm=True,
),
num_layers=int(num_layers * nar_scale_factor),
norm=AdaptiveLayerNorm(
nar_d_model, norm=nn.LayerNorm(nar_d_model)
)
norm=AdaptiveLayerNorm(nar_d_model, norm=nn.LayerNorm(nar_d_model))
if norm_first
else None,
)
@ -1105,10 +1078,7 @@ class VALLE(nn.Module):
]
)
self.nar_stage_embeddings = nn.ModuleList(
[
TokenEmbedding(nar_d_model, 1)
for i in range(num_quantizers - 1)
]
[TokenEmbedding(nar_d_model, 1) for i in range(num_quantizers - 1)]
)
if share_embedding:
@ -1119,9 +1089,9 @@ class VALLE(nn.Module):
# We also share the parameters of the acoustic embedding layer and the output prediction layer,
# which means the weights of the j-th prediction layer are the same as the (j + 1)-th acoustic embedding layer.
for j in range(0, num_quantizers - 2):
self.nar_predict_layers[
j
].weight = self.nar_audio_embeddings[j + 2].weight
self.nar_predict_layers[j].weight = self.nar_audio_embeddings[
j + 2
].weight
self.nar_accuracy_metric = MulticlassAccuracy(
NUM_AUDIO_TOKENS + 1,
@ -1192,13 +1162,9 @@ class VALLE(nn.Module):
y_prompts = self.nar_audio_embeddings[0](y[:, :prefix_len])
y_emb = self.nar_audio_embeddings[0](y[:, prefix_len:])
for j in range(1, self.num_quantizers):
y_prompts += self.nar_audio_embeddings[j](
codes[:, :prefix_len, j]
)
y_prompts += self.nar_audio_embeddings[j](codes[:, :prefix_len, j])
if j < nar_stage:
y_emb += self.nar_audio_embeddings[j](
codes[:, prefix_len:, j]
)
y_emb += self.nar_audio_embeddings[j](codes[:, prefix_len:, j])
y_emb = torch.concat([y_prompts, y_emb], axis=1)
elif self.prefix_mode in [2, 4]:
if self.prefix_mode == 2:
@ -1211,9 +1177,7 @@ class VALLE(nn.Module):
y_prompts_codes.append(
torch.clone(codes[b, start : start + prefix_len])
)
codes[
b, start : start + prefix_len, nar_stage
] = NUM_AUDIO_TOKENS
codes[b, start : start + prefix_len, nar_stage] = NUM_AUDIO_TOKENS
y_prompts_codes = torch.stack(y_prompts_codes, dim=0)
else:
prefix_len = y_prompts_codes.shape[1]
@ -1221,9 +1185,7 @@ class VALLE(nn.Module):
y_prompts = self.nar_audio_embeddings[0](y_prompts_codes[..., 0])
y_emb = self.nar_audio_embeddings[0](y)
for j in range(1, self.num_quantizers):
y_prompts += self.nar_audio_embeddings[j](
y_prompts_codes[..., j]
)
y_prompts += self.nar_audio_embeddings[j](y_prompts_codes[..., j])
if j < nar_stage:
y_emb += self.nar_audio_embeddings[j](codes[..., j])
y_emb = torch.concat([y_prompts, y_emb], axis=1)
@ -1290,9 +1252,7 @@ class VALLE(nn.Module):
text = x
codes = y.type(torch.int64) * (1 - y_mask_int.unsqueeze(dim=-1))
y, targets = self.pad_y_eos(
codes[..., 0], y_mask_int, eos_id=NUM_AUDIO_TOKENS
)
y, targets = self.pad_y_eos(codes[..., 0], y_mask_int, eos_id=NUM_AUDIO_TOKENS)
x_len = x_lens.max()
@ -1408,21 +1368,16 @@ class VALLE(nn.Module):
xy_dec = xy_dec[:, x_lens.max() + prefix_len :]
if self.prefix_mode == 4:
prefix_len = 0 # reset for Top10Accuracy metric
logits = self.nar_predict_layers[nar_stage - 1](xy_dec).permute(
0, 2, 1
)
logits = self.nar_predict_layers[nar_stage - 1](xy_dec).permute(0, 2, 1)
# loss
total_length = (y_lens).sum().type(torch.float32)
total_loss += (
F.cross_entropy(
logits,
targets,
ignore_index=NUM_AUDIO_TOKENS,
reduction=reduction,
)
* (total_length / (total_length - prefix_len * x.shape[0]))
)
total_loss += F.cross_entropy(
logits,
targets,
ignore_index=NUM_AUDIO_TOKENS,
reduction=reduction,
) * (total_length / (total_length - prefix_len * x.shape[0]))
metrics["NarTop10Accuracy"] = (
self.nar_accuracy_metric(
F.pad(
@ -1505,24 +1460,27 @@ class VALLE(nn.Module):
value=True,
)
y_attn_mask = F.pad(
torch.triu(
torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1
),
torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1),
(x_len, 0),
value=False,
)
xy_attn_mask = torch.concat(
[x_attn_mask_pad, y_attn_mask], dim=0
).to(y.device)
xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0).to(
y.device
)
xy_dec, _ = self.ar_decoder(
(xy_pos, None),
mask=xy_attn_mask,
)
logits = self.ar_predict_layer(xy_dec[:, -1])
ras=True
ras = True
samples = topk_sampling(
logits, top_k=top_k, top_p=top_p, temperature=temperature, repetition_aware_sampling=ras, preceding_tokens=y
logits,
top_k=top_k,
top_p=top_p,
temperature=temperature,
repetition_aware_sampling=ras,
preceding_tokens=y,
)
if (
@ -1531,9 +1489,7 @@ class VALLE(nn.Module):
or (y.shape[1] - prompts.shape[1]) > x_lens.max() * 16
):
if prompts.shape[1] == y.shape[1]:
raise SyntaxError(
"well trained model shouldn't reach here."
)
raise SyntaxError("well trained model shouldn't reach here.")
print(f"VALL-E EOS [{prompts.shape[1]} -> {y.shape[1]}]")
break
@ -1545,9 +1501,7 @@ class VALLE(nn.Module):
return torch.stack(codes, dim=-1)
# Non-AR Decoders
y_emb = self.nar_audio_embeddings[0](
y[:, int(self.ar_audio_prepend_bos) :]
)
y_emb = self.nar_audio_embeddings[0](y[:, int(self.ar_audio_prepend_bos) :])
if self.prefix_mode in [2, 4]: # Exclude enrolled_phonemes
enrolled_len = enroll_x_lens.max().item()
@ -1586,15 +1540,11 @@ class VALLE(nn.Module):
codes.append(samples)
if i < self.num_quantizers - 2:
y_emb[:, :prefix_len] += embedding_layer(
prompts[..., i + 1]
)
y_emb[:, :prefix_len] += embedding_layer(prompts[..., i + 1])
y_emb[:, prefix_len:] += embedding_layer(samples)
else:
for j in range(1, self.num_quantizers):
y_emb[:, :prefix_len] += self.nar_audio_embeddings[j](
prompts[..., j]
)
y_emb[:, :prefix_len] += self.nar_audio_embeddings[j](prompts[..., j])
for i, (predict_layer, embedding_layer) in enumerate(
zip(
@ -1687,15 +1637,11 @@ class VALLE(nn.Module):
codes.append(samples)
if i < 6:
y_emb[:, :prefix_len] += embedding_layer(
prompts[..., i + 1]
)
y_emb[:, :prefix_len] += embedding_layer(prompts[..., i + 1])
y_emb[:, prefix_len:] += embedding_layer(samples)
else:
for j in range(1, 8):
y_emb[:, :prefix_len] += self.nar_audio_embeddings[j](
prompts[..., j]
)
y_emb[:, :prefix_len] += self.nar_audio_embeddings[j](prompts[..., j])
for i, (predict_layer, embedding_layer) in enumerate(
zip(
@ -1736,18 +1682,14 @@ def top_k_top_p_filtering(
From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
"""
if top_k > 0:
top_k = min(
max(top_k, min_tokens_to_keep), logits.size(-1)
) # Safety check
top_k = min(max(top_k, min_tokens_to_keep), logits.size(-1)) # Safety check
# Remove all tokens with a probability less than the last token of the top-k
indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
logits[indices_to_remove] = filter_value
if top_p < 1.0:
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
cumulative_probs = torch.cumsum(
F.softmax(sorted_logits, dim=-1), dim=-1
)
cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
# Remove tokens with cumulative probability above the threshold (token with 0 are kept)
sorted_indices_to_remove = cumulative_probs > top_p
@ -1755,9 +1697,7 @@ def top_k_top_p_filtering(
# Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
sorted_indices_to_remove[..., :min_tokens_to_keep] = 0
# Shift the indices to the right to keep also the first token above the threshold
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[
..., :-1
].clone()
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
sorted_indices_to_remove[..., 0] = 0
# scatter sorted tensors to original indexing
@ -1768,7 +1708,14 @@ def top_k_top_p_filtering(
return logits
def topk_sampling(logits, top_k=10, top_p=1.0, temperature=1.0, repetition_aware_sampling=False, preceding_tokens=None):
def topk_sampling(
logits,
top_k=10,
top_p=1.0,
temperature=1.0,
repetition_aware_sampling=False,
preceding_tokens=None,
):
# temperature: (`optional`) float
# The value used to module the next token probabilities. Must be strictly positive. Default to 1.0.
# top_k: (`optional`) int
@ -1780,11 +1727,13 @@ def topk_sampling(logits, top_k=10, top_p=1.0, temperature=1.0, repetition_aware
if temperature != 1.0:
logits = logits / temperature
# Top-p/top-k filtering
logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p, min_tokens_to_keep=2)
logits = top_k_top_p_filtering(
logits, top_k=top_k, top_p=top_p, min_tokens_to_keep=2
)
# Sample
probs = F.softmax(logits, dim=-1)
# print top 10 value and index
print("top 10 value and index", torch.topk(probs, 10), top_p)
print("top 10 value and index", torch.topk(probs, 10), top_p)
tokens = torch.multinomial(probs, num_samples=1)
if repetition_aware_sampling:
@ -1814,7 +1763,9 @@ def topk_sampling(logits, top_k=10, top_p=1.0, temperature=1.0, repetition_aware
probs = F.softmax(logits[i], dim=-1)
token_new = torch.multinomial(probs, num_samples=1)
print(f"Repetition Aware Sampling: {item}, {tokens[i]} -> {token_new}")
print(
f"Repetition Aware Sampling: {item}, {tokens[i]} -> {token_new}"
)
print("probs", probs, logits.shape)
tokens[i] = token_new
else: