add infer code

root 2024-11-19 08:12:54 +00:00
parent 5361ecdc56
commit d55a534af8
15 changed files with 1029 additions and 855 deletions

View File

@@ -1,575 +0,0 @@
#!/usr/bin/env python3
# Copyright 2023 (authors: Feiteng Li)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Phonemize Text and EnCodec Audio.
Usage example:
python3 bin/tokenizer.py \
--src-dir ./data/manifests --output-dir ./data/tokenized
"""
import argparse
import logging
import os
from pathlib import Path
import torch
import torch.multiprocessing
from icefall.utils import get_executor
from lhotse import CutSet, NumpyHdf5Writer
from lhotse.recipes.utils import read_manifests_if_cached
from tqdm.auto import tqdm
from valle.data import (
AudioTokenConfig,
AudioTokenExtractor,
TextTokenizer,
tokenize_text,
)
# from valle.data.fbank import get_fbank_extractor
from valle.utils import SymbolTable
os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"
# Torch's multithreaded behavior needs to be disabled or
# it wastes a lot of CPU and slows things down.
# Do this outside of main() in case it needs to take effect
# even when we are not invoking the main (e.g. when spawning subprocesses).
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
torch.multiprocessing.set_sharing_strategy("file_system")
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--src-dir",
type=Path,
default=Path("data/manifests"),
help="Path to the manifest files",
)
parser.add_argument(
"--output-dir",
type=Path,
default=Path("data/tokenized"),
help="Path to the tokenized files",
)
parser.add_argument(
"--text-extractor",
type=str,
default="espeak",
help="espeak or pypinyin or pypinyin_initials_finals",
)
parser.add_argument(
"--audio-extractor",
type=str,
default="Encodec",
help="Encodec or Fbank",
)
parser.add_argument(
"--dataset-parts",
type=str,
default="dev-clean test-clean",
help="Space separated dataset parts",
)
parser.add_argument(
"--prefix",
type=str,
default="libritts",
help="prefix of the manifest file",
)
parser.add_argument(
"--suffix",
type=str,
default="jsonl.gz",
help="suffix of the manifest file",
)
parser.add_argument(
"--batch-duration",
type=float,
default=400.0,
help="The maximum number of audio seconds in a batch."
"Determines batch size dynamically.",
)
parser.add_argument(
"--split",
type=int,
default=1,
help="Split the cut_set into multiple parts",
)
return parser.parse_args()
class PypinyinBackend:
"""PypinyinBackend for Chinese. Most codes is referenced from espnet.
There are two types pinyin or initials_finals, one is
just like "ni1 hao3", the other is like "n i1 h ao3".
"""
def __init__(
self,
backend="initials_finals",
punctuation_marks: Union[str, Pattern] = Punctuation.default_marks(),
) -> None:
self.backend = backend
self.punctuation_marks = punctuation_marks
def phonemize(
self, text: List[str], separator: Separator, strip=True, njobs=1
) -> List[str]:
assert isinstance(text, List)
phonemized = []
for _text in text:
_text = re.sub(" +", " ", _text.strip())
_text = _text.replace(" ", separator.word)
phones = []
if self.backend == "pypinyin":
for n, py in enumerate(
pinyin(
_text, style=Style.TONE3, neutral_tone_with_five=True
)
):
if all([c in self.punctuation_marks for c in py[0]]):
if len(phones):
assert phones[-1] == separator.syllable
phones.pop(-1)
phones.extend(list(py[0]))
else:
phones.extend([py[0], separator.syllable])
elif self.backend == "pypinyin_initials_finals":
for n, py in enumerate(
pinyin(
_text, style=Style.TONE3, neutral_tone_with_five=True
)
):
if all([c in self.punctuation_marks for c in py[0]]):
if len(phones):
assert phones[-1] == separator.syllable
phones.pop(-1)
phones.extend(list(py[0]))
else:
if py[0][-1].isalnum():
initial = get_initials(py[0], strict=False)
if py[0][-1].isdigit():
final = (
get_finals(py[0][:-1], strict=False)
+ py[0][-1]
)
else:
final = get_finals(py[0], strict=False)
phones.extend(
[
initial,
separator.phone,
final,
separator.syllable,
]
)
else:
raise ValueError(f"Invalid pinyin: {py[0]}")
else:
raise NotImplementedError
phonemized.append(
"".join(phones).rstrip(f"{separator.word}{separator.syllable}")
)
return phonemized
class TextTokenizer:
"""Phonemize Text."""
def __init__(
self,
language="en-us",
backend="espeak",
separator=Separator(word="_", syllable="-", phone="|"),
preserve_punctuation=True,
punctuation_marks: Union[str, Pattern] = Punctuation.default_marks(),
with_stress: bool = False,
tie: Union[bool, str] = False,
language_switch: LanguageSwitch = "keep-flags",
words_mismatch: WordMismatch = "ignore",
) -> None:
if backend == "espeak":
phonemizer = EspeakBackend(
language,
punctuation_marks=punctuation_marks,
preserve_punctuation=preserve_punctuation,
with_stress=with_stress,
tie=tie,
language_switch=language_switch,
words_mismatch=words_mismatch,
)
elif backend in ["pypinyin", "pypinyin_initials_finals"]:
phonemizer = PypinyinBackend(
backend=backend,
punctuation_marks=punctuation_marks + separator.word,
)
else:
raise NotImplementedError(f"{backend}")
self.backend = phonemizer
self.separator = separator
def to_list(self, phonemized: str) -> List[str]:
fields = []
for word in phonemized.split(self.separator.word):
# "ɐ m|iː|n?" ɹ|ɪ|z|ɜː|v; h|ɪ|z.
pp = re.findall(r"\w+|[^\w\s]", word, re.UNICODE)
fields.extend(
[p for p in pp if p != self.separator.phone]
+ [self.separator.word]
)
assert len("".join(fields[:-1])) == len(phonemized) - phonemized.count(
self.separator.phone
)
return fields[:-1]
def __call__(self, text, strip=True) -> List[List[str]]:
if isinstance(text, str):
text = [text]
phonemized = self.backend.phonemize(
text, separator=self.separator, strip=strip, njobs=1
)
return [self.to_list(p) for p in phonemized]
def tokenize_text(tokenizer: TextTokenizer, text: str) -> List[str]:
phonemes = tokenizer([text.strip()])
return phonemes[0] # k2symbols
def remove_encodec_weight_norm(model):
from encodec.modules import SConv1d
from encodec.modules.seanet import SConvTranspose1d, SEANetResnetBlock
from torch.nn.utils import remove_weight_norm
encoder = model.encoder.model
for key in encoder._modules:
if isinstance(encoder._modules[key], SEANetResnetBlock):
remove_weight_norm(encoder._modules[key].shortcut.conv.conv)
block_modules = encoder._modules[key].block._modules
for skey in block_modules:
if isinstance(block_modules[skey], SConv1d):
remove_weight_norm(block_modules[skey].conv.conv)
elif isinstance(encoder._modules[key], SConv1d):
remove_weight_norm(encoder._modules[key].conv.conv)
decoder = model.decoder.model
for key in decoder._modules:
if isinstance(decoder._modules[key], SEANetResnetBlock):
remove_weight_norm(decoder._modules[key].shortcut.conv.conv)
block_modules = decoder._modules[key].block._modules
for skey in block_modules:
if isinstance(block_modules[skey], SConv1d):
remove_weight_norm(block_modules[skey].conv.conv)
elif isinstance(decoder._modules[key], SConvTranspose1d):
remove_weight_norm(decoder._modules[key].convtr.convtr)
elif isinstance(decoder._modules[key], SConv1d):
remove_weight_norm(decoder._modules[key].conv.conv)
class AudioTokenizer:
"""EnCodec audio."""
def __init__(
self,
device: Any = None,
) -> None:
# Instantiate a pretrained EnCodec model
model = EncodecModel.encodec_model_24khz()
model.set_target_bandwidth(6.0)
remove_encodec_weight_norm(model)
if not device:
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda:0")
self._device = device
self.codec = model.to(device)
self.sample_rate = model.sample_rate
self.channels = model.channels
@property
def device(self):
return self._device
def encode(self, wav: torch.Tensor) -> torch.Tensor:
return self.codec.encode(wav.to(self.device))
def decode(self, frames: torch.Tensor) -> torch.Tensor:
return self.codec.decode(frames)
@dataclass
class AudioTokenConfig:
frame_shift: Seconds = 320.0 / 24000
num_quantizers: int = 8
def to_dict(self) -> Dict[str, Any]:
return asdict(self)
@staticmethod
def from_dict(data: Dict[str, Any]) -> "AudioTokenConfig":
return AudioTokenConfig(**data)
class AudioTokenExtractor(FeatureExtractor):
name = "encodec"
config_type = AudioTokenConfig
def __init__(self, config: Optional[Any] = None):
super(AudioTokenExtractor, self).__init__(config)
self.tokenizer = AudioTokenizer()
def extract(
self, samples: Union[np.ndarray, torch.Tensor], sampling_rate: int
) -> np.ndarray:
if not isinstance(samples, torch.Tensor):
samples = torch.from_numpy(samples)
if sampling_rate != self.tokenizer.sample_rate:
samples = convert_audio(
samples,
sampling_rate,
self.tokenizer.sample_rate,
self.tokenizer.channels,
)
if len(samples.shape) == 2:
samples = samples.unsqueeze(0)
else:
raise ValueError()
device = self.tokenizer.device
encoded_frames = self.tokenizer.encode(samples.detach().to(device))
codes = encoded_frames[0][0] # [B, n_q, T]
if True:
duration = round(samples.shape[-1] / sampling_rate, ndigits=12)
expected_num_frames = compute_num_frames(
duration=duration,
frame_shift=self.frame_shift,
sampling_rate=sampling_rate,
)
assert abs(codes.shape[-1] - expected_num_frames) <= 1
codes = codes[..., :expected_num_frames]
return codes.cpu().squeeze(0).permute(1, 0).numpy()
@property
def frame_shift(self) -> Seconds:
return self.config.frame_shift
def feature_dim(self, sampling_rate: int) -> int:
return self.config.num_quantizers
def pad_tensor_list(self, tensor_list, device, padding_value=0):
# Compute the length of each tensor
lengths = [tensor.shape[0] for tensor in tensor_list]
# Pad with torch.nn.utils.rnn.pad_sequence
tensor_list = [torch.Tensor(t).to(device) for t in tensor_list]
padded_tensor = torch.nn.utils.rnn.pad_sequence(
tensor_list, batch_first=True, padding_value=padding_value
)
return padded_tensor, lengths
def extract_batch(self, samples, sampling_rate, lengths) -> np.ndarray:
samples = [wav.squeeze() for wav in samples]
device = self.tokenizer.device
samples, lengths = self.pad_tensor_list(samples, device)
samples = samples.unsqueeze(1)
if not isinstance(samples, torch.Tensor):
samples = torch.from_numpy(samples)
if len(samples.shape) != 3:
raise ValueError()
if sampling_rate != self.tokenizer.sample_rate:
samples = [
convert_audio(
wav,
sampling_rate,
self.tokenizer.sample_rate,
self.tokenizer.channels,
)
for wav in samples
]
samples = torch.stack(samples, 0) # convert samples from list to tensor
# Extract discrete codes from EnCodec
with torch.no_grad():
encoded_frames = self.tokenizer.encode(samples.detach().to(device))
encoded_frames = encoded_frames[0][0] # [B, n_q, T]
batch_codes = []
for b, length in enumerate(lengths):
codes = encoded_frames[b]
duration = round(length / sampling_rate, ndigits=12)
expected_num_frames = compute_num_frames(
duration=duration,
frame_shift=self.frame_shift,
sampling_rate=sampling_rate,
)
batch_codes.append(codes[..., :expected_num_frames])
return [codes.cpu().permute(1, 0).numpy() for codes in batch_codes]
def main():
args = get_args()
dataset_parts = args.dataset_parts.replace("--dataset-parts", "").strip()
if dataset_parts == "all": # LibriTTS
dataset_parts = [
"dev-clean",
"dev-other",
"test-clean",
"test-other",
"train-clean-100",
"train-clean-360",
"train-other-500",
]
else:
dataset_parts = dataset_parts.replace("-p", "").strip().split(" ")
assert len(dataset_parts) >= 1
manifests = read_manifests_if_cached(
dataset_parts=dataset_parts,
output_dir=args.src_dir,
prefix=args.prefix,
suffix=args.suffix,
types=["recordings", "supervisions", "cuts"],
)
text_tokenizer = None
if args.text_extractor:
text_tokenizer = TextTokenizer(backend=args.text_extractor)
audio_extractor = None
if args.audio_extractor:
if args.audio_extractor == "Encodec":
audio_extractor = AudioTokenExtractor(AudioTokenConfig())
else:
assert args.audio_extractor == "Fbank"
audio_extractor = get_fbank_extractor()
Path(args.output_dir).mkdir(parents=True, exist_ok=True)
unique_symbols = set()
num_jobs = min(32, os.cpu_count())
logging.info(f"dataset_parts: {dataset_parts} manifests {len(manifests)}")
prefix = args.prefix
if prefix and not prefix.endswith("_"):
prefix = f"{prefix}_"
with get_executor() as ex:
for partition, m in manifests.items():
logging.info(
f"Processing partition: {partition} CUDA: {torch.cuda.is_available()}"
)
try:
cut_set = CutSet.from_manifests(
recordings=m["recordings"],
supervisions=m["supervisions"],
)
except Exception:
cut_set = m["cuts"]
# Split cut_set if split > 1
split = 1
if args.split > 1:
cut_sets = cut_set.split(args.split)
split = args.split
else:
cut_sets = [cut_set]
for idx, part in enumerate(cut_sets):
# AudioTokenizer
if args.audio_extractor:
if args.audio_extractor == "Encodec":
storage_path = (
f"{args.output_dir}/{args.prefix}_encodec_{partition}_{idx if split > 1 else ''}"
)
else:
storage_path = (
f"{args.output_dir}/{args.prefix}_fbank_{partition}_{idx if split > 1 else ''}"
)
if args.prefix.lower() in ["ljspeech", "aishell", "baker", "wenetspeech4tts"]:
part = part.resample(24000)
with torch.no_grad():
if (
torch.cuda.is_available()
and args.audio_extractor == "Encodec"
):
part = part.compute_and_store_features_batch(
extractor=audio_extractor,
storage_path=storage_path,
num_workers=num_jobs,
batch_duration=args.batch_duration,
collate=False,
overwrite=True,
storage_type=NumpyHdf5Writer,
)
else:
part = part.compute_and_store_features(
extractor=audio_extractor,
storage_path=storage_path,
num_jobs=num_jobs if ex is None else 64,
executor=ex,
storage_type=NumpyHdf5Writer,
)
# TextTokenizer
if args.text_extractor:
for c in tqdm(part):
if args.prefix == "baker" and args.text_extractor == "labeled_pinyin":
phonemes = c.supervisions[0].custom["tokens"]["text"]
unique_symbols.update(phonemes)
else:
if args.prefix == "ljspeech":
text = c.supervisions[0].custom["normalized_text"]
text = text.replace("“", '"').replace("”", '"')
phonemes = tokenize_text(text_tokenizer, text=text)
elif args.prefix in ["aishell", "aishell2", "wenetspeech4tts", "libritts"]:
phonemes = tokenize_text(
text_tokenizer, text=c.supervisions[0].text
)
if c.supervisions[0].custom is None:
c.supervisions[0].custom = {}
else:
raise NotImplementedError(f"{args.prefix}")
c.supervisions[0].custom["tokens"] = {"text": phonemes}
unique_symbols.update(phonemes)
# Save each part with an index if split > 1
cuts_filename = f"{prefix}cuts_{partition}.{idx if split > 1 else ''}.{args.suffix}"
part.to_file(f"{args.output_dir}/{cuts_filename}")
logging.info(f"Saved {cuts_filename}")
if args.text_extractor:
unique_phonemes = SymbolTable()
for s in sorted(list(unique_symbols)):
unique_phonemes.add(s)
logging.info(f"{len(unique_symbols)} unique phonemes: {unique_symbols}")
unique_phonemes_file = f"{args.output_dir}/unique_text_tokens.k2symbols"
unique_phonemes.to_file(unique_phonemes_file)
if __name__ == "__main__":
formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@@ -0,0 +1 @@
../../../wenetspeech4tts/TTS/local/compute_neural_codec_and_prepare_text_tokens.py

View File

@@ -32,7 +32,7 @@ if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then
cd vits/monotonic_align
python setup.py build_ext --inplace
cd ../../
else
log "monotonic_align lib already built"
fi
fi
@@ -75,11 +75,11 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
log "Stage 2: Compute Spectrogram for LibriTTS"
mkdir -p data/spectrogram
if [ ! -e data/spectrogram/.libritts.done ]; then
./local/compute_spectrogram_libritts.py --sampling-rate $sampling_rate
touch data/spectrogram/.libritts.done
fi
# Here we shuffle and combine the train-clean-100, train-clean-360 and
# train-other-500 together to form the training set.
if [ ! -f data/spectrogram/libritts_cuts_train-all-shuf.jsonl.gz ]; then
cat <(gunzip -c data/spectrogram/libritts_cuts_train-clean-100.jsonl.gz) \
@@ -88,7 +88,7 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
shuf | gzip -c > data/spectrogram/libritts_cuts_train-all-shuf.jsonl.gz
fi
# Here we shuffle and combine the train-clean-100, train-clean-360
# together to form the training set.
if [ ! -f data/spectrogram/libritts_cuts_train-clean-460.jsonl.gz ]; then
cat <(gunzip -c data/spectrogram/libritts_cuts_train-clean-100.jsonl.gz) \
@@ -108,10 +108,10 @@ if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
log "Stage 3: Prepare phoneme tokens for LibriTTS"
# We assume you have installed piper_phonemize and espnet_tts_frontend.
# If not, please install them with:
# - piper_phonemize:
# refer to https://github.com/rhasspy/piper-phonemize,
# could install the pre-built wheels from https://github.com/csukuangfj/piper-phonemize/releases/tag/2023.12.5
# - espnet_tts_frontend:
# `pip install espnet_tts_frontend`, refer to https://github.com/espnet/espnet_tts_frontend/
if [ ! -e data/spectrogram/.libritts_with_token.done ]; then
./local/prepare_tokens_libritts.py
@@ -123,12 +123,39 @@ if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
log "Stage 4: Generate token file"
# We assume you have installed piper_phonemize and espnet_tts_frontend.
# If not, please install them with:
# - piper_phonemize:
# refer to https://github.com/rhasspy/piper-phonemize,
# could install the pre-built wheels from https://github.com/csukuangfj/piper-phonemize/releases/tag/2023.12.5
# - espnet_tts_frontend:
# `pip install espnet_tts_frontend`, refer to https://github.com/espnet/espnet_tts_frontend/
if [ ! -e data/tokens.txt ]; then
./local/prepare_token_file.py --tokens data/tokens.txt
fi
fi
audio_feats_dir=data/tokenized
dataset_parts="--dataset-parts all" # debug "-p dev-clean -p test-clean"
if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
log "Stage 5: Tokenize/Fbank LibriTTS for valle"
mkdir -p ${audio_feats_dir}
if [ ! -e ${audio_feats_dir}/.libritts.tokenize.done ]; then
python3 ./local/compute_neural_codec_and_prepare_text_tokens.py --dataset-parts "${dataset_parts}" \
--audio-extractor "Encodec" \
--batch-duration 400 \
--src-dir "data/manifests" \
--output-dir "${audio_feats_dir}"
fi
touch ${audio_feats_dir}/.libritts.tokenize.done
lhotse combine \
${audio_feats_dir}/libritts_cuts_train-clean-100.jsonl.gz \
${audio_feats_dir}/libritts_cuts_train-clean-360.jsonl.gz \
${audio_feats_dir}/libritts_cuts_train-other-500.jsonl.gz \
${audio_feats_dir}/cuts_train.jsonl.gz
lhotse copy \
${audio_feats_dir}/libritts_cuts_dev-clean.jsonl.gz \
${audio_feats_dir}/cuts_dev.jsonl.gz
lhotse copy \
${audio_feats_dir}/libritts_cuts_test-clean.jsonl.gz \
${audio_feats_dir}/cuts_test.jsonl.gz
fi

egs/libritts/TTS/valle Symbolic link
View File

@@ -0,0 +1 @@
../../wenetspeech4tts/TTS/valle/

View File

@@ -0,0 +1,51 @@
# Introduction
LibriTTS is a multi-speaker English corpus of approximately 585 hours of read English speech at 24kHz sampling rate, prepared by Heiga Zen with the assistance of Google Speech and Google Brain team members.
The LibriTTS corpus is designed for TTS research. It is derived from the original materials (mp3 audio files from LibriVox and text files from Project Gutenberg) of the LibriSpeech corpus.
The main differences from the LibriSpeech corpus are listed below:
1. The audio files are at 24kHz sampling rate.
2. The speech is split at sentence breaks.
3. Both original and normalized texts are included.
4. Contextual information (e.g., neighbouring sentences) can be extracted.
5. Utterances with significant background noise are excluded.
For more information, refer to the paper "LibriTTS: A Corpus Derived from LibriSpeech for Text-to-Speech", Heiga Zen, Viet Dang, Rob Clark, Yu Zhang, Ron J. Weiss, Ye Jia, Zhifeng Chen, and Yonghui Wu, arXiv, 2019. If you use the LibriTTS corpus in your work, please cite this paper where it was introduced.
> [!CAUTION]
> The next-gen Kaldi framework provides tools and models for generating high-quality synthetic speech (Text-to-Speech, TTS).
> While these recipes have the potential to advance various fields such as accessibility, language education, and AI-driven solutions, they also carry certain ethical and legal responsibilities.
>
> By using this framework, you agree to the following:
> 1. Legal and Ethical Use: You shall not use this framework, or any models derived from it, for any unlawful or unethical purposes. This includes, but is not limited to: Creating voice clones without the explicit, informed consent of the individual whose voice is being cloned. Engaging in any form of identity theft, impersonation, or fraud using cloned voices. Violating any local, national, or international laws regarding privacy, intellectual property, or personal data.
>
> 2. Responsibility of Use: The users of this framework are solely responsible for ensuring that their use of voice cloning technologies complies with all applicable laws and ethical guidelines. We explicitly disclaim any liability for misuse of the technology.
>
> 3. Attribution and Use of Open-Source Components: This project is provided under the Apache 2.0 license. Users must adhere to the terms of this license and provide appropriate attribution when required.
>
> 4. No Warranty: This framework is provided “as-is,” without warranty of any kind, either express or implied. We do not guarantee that the use of this software will comply with legal requirements or that it will not infringe the rights of third parties.
# VITS
This recipe provides a VITS model trained on the LibriTTS dataset.
The pretrained model can be found [here](https://huggingface.co/zrjin/icefall-tts-libritts-vits-2024-10-30).
The training command is given below:
```
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
./vits/train.py \
--world-size 4 \
--num-epochs 400 \
--start-epoch 1 \
--use-fp16 1 \
--exp-dir vits/exp \
--max-duration 500
```
To run inference, use:
```
./vits/infer.py \
--exp-dir vits/exp \
--epoch 400 \
--tokens data/tokens.txt
```

View File

@@ -0,0 +1,615 @@
#!/usr/bin/env python3
# Copyright 2023 (authors: Feiteng Li)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Phonemize Text and EnCodec Audio.
Usage example:
python3 local/compute_neural_codec_and_prepare_text_tokens.py \
--src-dir ./data/manifests --output-dir ./data/tokenized
"""
import argparse
import logging
import os
from dataclasses import asdict, dataclass
from pathlib import Path
from typing import Any, Dict, List, Optional, Union
import torch
import torch.multiprocessing
from encodec import EncodecModel
from encodec.utils import convert_audio
from lhotse import CutSet, NumpyHdf5Writer
from lhotse.features import FeatureExtractor
from lhotse.recipes.utils import read_manifests_if_cached
from lhotse.utils import Seconds, compute_num_frames
from phonemizer.backend import EspeakBackend
from phonemizer.backend.espeak.language_switch import LanguageSwitch
from phonemizer.backend.espeak.words_mismatch import WordMismatch
from phonemizer.punctuation import Punctuation
from phonemizer.separator import Separator
from tqdm.auto import tqdm
from icefall.utils import get_executor
try:
from pypinyin import Style, pinyin
from pypinyin.style._utils import get_finals, get_initials
except Exception:
pass
import re
from typing import Pattern
import numpy as np
from k2 import SymbolTable
# from valle.data import (
# AudioTokenConfig,
# AudioTokenExtractor,
# TextTokenizer,
# tokenize_text,
# )
# from valle.data.fbank import get_fbank_extractor
# from valle.utils import SymbolTable
os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"
# Torch's multithreaded behavior needs to be disabled or
# it wastes a lot of CPU and slows things down.
# Do this outside of main() in case it needs to take effect
# even when we are not invoking the main (e.g. when spawning subprocesses).
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
torch.multiprocessing.set_sharing_strategy("file_system")
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--src-dir",
type=Path,
default=Path("data/manifests"),
help="Path to the manifest files",
)
parser.add_argument(
"--output-dir",
type=Path,
default=Path("data/tokenized"),
help="Path to the tokenized files",
)
parser.add_argument(
"--text-extractor",
type=str,
default="espeak",
help="espeak or pypinyin or pypinyin_initials_finals",
)
parser.add_argument(
"--audio-extractor",
type=str,
default="Encodec",
help="Encodec or Fbank",
)
parser.add_argument(
"--dataset-parts",
type=str,
default="dev-clean test-clean",
help="Space separated dataset parts",
)
parser.add_argument(
"--prefix",
type=str,
default="libritts",
help="prefix of the manifest file",
)
parser.add_argument(
"--suffix",
type=str,
default="jsonl.gz",
help="suffix of the manifest file",
)
parser.add_argument(
"--batch-duration",
type=float,
default=400.0,
help="The maximum number of audio seconds in a batch."
"Determines batch size dynamically.",
)
parser.add_argument(
"--split",
type=int,
default=1,
help="Split the cut_set into multiple parts",
)
return parser.parse_args()
class PypinyinBackend:
"""PypinyinBackend for Chinese. Most codes is referenced from espnet.
There are two types pinyin or initials_finals, one is
just like "ni1 hao3", the other is like "n i1 h ao3".
"""
def __init__(
self,
backend="initials_finals",
punctuation_marks: Union[str, Pattern] = Punctuation.default_marks(),
) -> None:
self.backend = backend
self.punctuation_marks = punctuation_marks
def phonemize(
self, text: List[str], separator: Separator, strip=True, njobs=1
) -> List[str]:
assert isinstance(text, List)
phonemized = []
for _text in text:
_text = re.sub(" +", " ", _text.strip())
_text = _text.replace(" ", separator.word)
phones = []
if self.backend == "pypinyin":
for n, py in enumerate(
pinyin(_text, style=Style.TONE3, neutral_tone_with_five=True)
):
if all([c in self.punctuation_marks for c in py[0]]):
if len(phones):
assert phones[-1] == separator.syllable
phones.pop(-1)
phones.extend(list(py[0]))
else:
phones.extend([py[0], separator.syllable])
elif self.backend == "pypinyin_initials_finals":
for n, py in enumerate(
pinyin(_text, style=Style.TONE3, neutral_tone_with_five=True)
):
if all([c in self.punctuation_marks for c in py[0]]):
if len(phones):
assert phones[-1] == separator.syllable
phones.pop(-1)
phones.extend(list(py[0]))
else:
if py[0][-1].isalnum():
initial = get_initials(py[0], strict=False)
if py[0][-1].isdigit():
final = get_finals(py[0][:-1], strict=False) + py[0][-1]
else:
final = get_finals(py[0], strict=False)
phones.extend(
[
initial,
separator.phone,
final,
separator.syllable,
]
)
else:
raise ValueError(f"Invalid pinyin: {py[0]}")
else:
raise NotImplementedError
phonemized.append(
"".join(phones).rstrip(f"{separator.word}{separator.syllable}")
)
return phonemized
class TextTokenizer:
"""Phonemize Text."""
def __init__(
self,
language="en-us",
backend="espeak",
separator=Separator(word="_", syllable="-", phone="|"),
preserve_punctuation=True,
punctuation_marks: Union[str, Pattern] = Punctuation.default_marks(),
with_stress: bool = False,
tie: Union[bool, str] = False,
language_switch: LanguageSwitch = "keep-flags",
words_mismatch: WordMismatch = "ignore",
) -> None:
if backend == "espeak":
phonemizer = EspeakBackend(
language,
punctuation_marks=punctuation_marks,
preserve_punctuation=preserve_punctuation,
with_stress=with_stress,
tie=tie,
language_switch=language_switch,
words_mismatch=words_mismatch,
)
elif backend in ["pypinyin", "pypinyin_initials_finals"]:
phonemizer = PypinyinBackend(
backend=backend,
punctuation_marks=punctuation_marks + separator.word,
)
else:
raise NotImplementedError(f"{backend}")
self.backend = phonemizer
self.separator = separator
def to_list(self, phonemized: str) -> List[str]:
fields = []
for word in phonemized.split(self.separator.word):
# "ɐ m|iː|n?" ɹ|ɪ|z|ɜː|v; h|ɪ|z.
pp = re.findall(r"\w+|[^\w\s]", word, re.UNICODE)
fields.extend(
[p for p in pp if p != self.separator.phone] + [self.separator.word]
)
assert len("".join(fields[:-1])) == len(phonemized) - phonemized.count(
self.separator.phone
)
return fields[:-1]
def __call__(self, text, strip=True) -> List[List[str]]:
if isinstance(text, str):
text = [text]
phonemized = self.backend.phonemize(
text, separator=self.separator, strip=strip, njobs=1
)
return [self.to_list(p) for p in phonemized]
def tokenize_text(tokenizer: TextTokenizer, text: str) -> List[str]:
phonemes = tokenizer([text.strip()])
return phonemes[0] # k2symbols
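# NOTE: EnCodec's pretrained convolutions are wrapped in torch's weight_norm.
# Folding the weight_g/weight_v reparameterization back into plain weights
# (torch.nn.utils.remove_weight_norm) is an inference-time simplification
# and leaves the module's outputs unchanged.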
def remove_encodec_weight_norm(model):
from encodec.modules import SConv1d
from encodec.modules.seanet import SConvTranspose1d, SEANetResnetBlock
from torch.nn.utils import remove_weight_norm
encoder = model.encoder.model
for key in encoder._modules:
if isinstance(encoder._modules[key], SEANetResnetBlock):
remove_weight_norm(encoder._modules[key].shortcut.conv.conv)
block_modules = encoder._modules[key].block._modules
for skey in block_modules:
if isinstance(block_modules[skey], SConv1d):
remove_weight_norm(block_modules[skey].conv.conv)
elif isinstance(encoder._modules[key], SConv1d):
remove_weight_norm(encoder._modules[key].conv.conv)
decoder = model.decoder.model
for key in decoder._modules:
if isinstance(decoder._modules[key], SEANetResnetBlock):
remove_weight_norm(decoder._modules[key].shortcut.conv.conv)
block_modules = decoder._modules[key].block._modules
for skey in block_modules:
if isinstance(block_modules[skey], SConv1d):
remove_weight_norm(block_modules[skey].conv.conv)
elif isinstance(decoder._modules[key], SConvTranspose1d):
remove_weight_norm(decoder._modules[key].convtr.convtr)
elif isinstance(decoder._modules[key], SConv1d):
remove_weight_norm(decoder._modules[key].conv.conv)
class AudioTokenizer:
"""EnCodec audio."""
def __init__(
self,
device: Any = None,
) -> None:
# Instantiate a pretrained EnCodec model
model = EncodecModel.encodec_model_24khz()
model.set_target_bandwidth(6.0)
remove_encodec_weight_norm(model)
if not device:
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda:0")
self._device = device
self.codec = model.to(device)
self.sample_rate = model.sample_rate
self.channels = model.channels
@property
def device(self):
return self._device
def encode(self, wav: torch.Tensor) -> torch.Tensor:
return self.codec.encode(wav.to(self.device))
def decode(self, frames: torch.Tensor) -> torch.Tensor:
return self.codec.decode(frames)
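# At 24 kHz, EnCodec advances 320 samples per code frame, i.e. 75 frames/s
# (frame_shift = 320 / 24000); the 6.0 kbps target bandwidth set above
# corresponds to 8 residual quantizers, matching num_quantizers below.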
@dataclass
class AudioTokenConfig:
frame_shift: Seconds = 320.0 / 24000
num_quantizers: int = 8
def to_dict(self) -> Dict[str, Any]:
return asdict(self)
@staticmethod
def from_dict(data: Dict[str, Any]) -> "AudioTokenConfig":
return AudioTokenConfig(**data)
class AudioTokenExtractor(FeatureExtractor):
name = "encodec"
config_type = AudioTokenConfig
def __init__(self, config: Optional[Any] = None):
super(AudioTokenExtractor, self).__init__(config)
self.tokenizer = AudioTokenizer()
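# extract() encodes one utterance and returns a [T, num_quantizers] array of
# integer codebook ids; extract_batch() below pads a list of waveforms,
# encodes them in a single pass, and trims each result to its true length.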
def extract(
self, samples: Union[np.ndarray, torch.Tensor], sampling_rate: int
) -> np.ndarray:
if not isinstance(samples, torch.Tensor):
samples = torch.from_numpy(samples)
if sampling_rate != self.tokenizer.sample_rate:
samples = convert_audio(
samples,
sampling_rate,
self.tokenizer.sample_rate,
self.tokenizer.channels,
)
if len(samples.shape) == 2:
samples = samples.unsqueeze(0)
else:
raise ValueError()
device = self.tokenizer.device
encoded_frames = self.tokenizer.encode(samples.detach().to(device))
codes = encoded_frames[0][0] # [B, n_q, T]
if True:
duration = round(samples.shape[-1] / sampling_rate, ndigits=12)
expected_num_frames = compute_num_frames(
duration=duration,
frame_shift=self.frame_shift,
sampling_rate=sampling_rate,
)
assert abs(codes.shape[-1] - expected_num_frames) <= 1
codes = codes[..., :expected_num_frames]
return codes.cpu().squeeze(0).permute(1, 0).numpy()
@property
def frame_shift(self) -> Seconds:
return self.config.frame_shift
def feature_dim(self, sampling_rate: int) -> int:
return self.config.num_quantizers
def pad_tensor_list(self, tensor_list, device, padding_value=0):
lengths = [tensor.shape[0] for tensor in tensor_list]
tensor_list = [torch.Tensor(t).to(device) for t in tensor_list]
padded_tensor = torch.nn.utils.rnn.pad_sequence(
tensor_list, batch_first=True, padding_value=padding_value
)
return padded_tensor, lengths
def extract_batch(self, samples, sampling_rate, lengths) -> np.ndarray:
samples = [wav.squeeze() for wav in samples]
device = self.tokenizer.device
samples, lengths = self.pad_tensor_list(samples, device)
samples = samples.unsqueeze(1)
if not isinstance(samples, torch.Tensor):
samples = torch.from_numpy(samples)
if len(samples.shape) != 3:
raise ValueError()
if sampling_rate != self.tokenizer.sample_rate:
samples = [
convert_audio(
wav,
sampling_rate,
self.tokenizer.sample_rate,
self.tokenizer.channels,
)
for wav in samples
]
samples = torch.stack(samples, 0) # convert samples from list to tensor
# Extract discrete codes from EnCodec
with torch.no_grad():
encoded_frames = self.tokenizer.encode(samples.detach().to(device))
encoded_frames = encoded_frames[0][0] # [B, n_q, T]
batch_codes = []
for b, length in enumerate(lengths):
codes = encoded_frames[b]
duration = round(length / sampling_rate, ndigits=12)
expected_num_frames = compute_num_frames(
duration=duration,
frame_shift=self.frame_shift,
sampling_rate=sampling_rate,
)
batch_codes.append(codes[..., :expected_num_frames])
return [codes.cpu().permute(1, 0).numpy() for codes in batch_codes]
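# main() ties the pieces together: read the lhotse manifests, optionally split
# the CutSet, compute EnCodec codes (batched on GPU when available, otherwise
# through the CPU executor), attach phoneme tokens to each supervision, and
# finally write the cuts plus a SymbolTable of every phoneme seen.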
def main():
args = get_args()
dataset_parts = args.dataset_parts.replace("--dataset-parts", "").strip()
if dataset_parts == "all": # LibriTTS
dataset_parts = [
"dev-clean",
"dev-other",
"test-clean",
"test-other",
"train-clean-100",
"train-clean-360",
"train-other-500",
]
else:
dataset_parts = dataset_parts.replace("-p", "").strip().split(" ")
assert len(dataset_parts) >= 1
manifests = read_manifests_if_cached(
dataset_parts=dataset_parts,
output_dir=args.src_dir,
prefix=args.prefix,
suffix=args.suffix,
types=["recordings", "supervisions", "cuts"],
)
text_tokenizer = None
if args.text_extractor:
text_tokenizer = TextTokenizer(backend=args.text_extractor)
audio_extractor = None
if args.audio_extractor:
if args.audio_extractor == "Encodec":
audio_extractor = AudioTokenExtractor(AudioTokenConfig())
else:
raise NotImplementedError(f"{args.audio_extractor}")
Path(args.output_dir).mkdir(parents=True, exist_ok=True)
unique_symbols = set()
num_jobs = min(32, os.cpu_count())
logging.info(f"dataset_parts: {dataset_parts} manifests {len(manifests)}")
prefix = args.prefix
if prefix and not prefix.endswith("_"):
prefix = f"{prefix}_"
with get_executor() as ex:
for partition, m in manifests.items():
logging.info(
f"Processing partition: {partition} CUDA: {torch.cuda.is_available()}"
)
try:
cut_set = CutSet.from_manifests(
recordings=m["recordings"],
supervisions=m["supervisions"],
)
except Exception:
cut_set = m["cuts"]
# Split cut_set if split > 1
split = 1
if args.split > 1:
cut_sets = cut_set.split(args.split)
split = args.split
else:
cut_sets = [cut_set]
for idx, part in enumerate(cut_sets):
if args.audio_extractor:
if args.audio_extractor == "Encodec":
storage_path = f"{args.output_dir}/{args.prefix}_encodec_{partition}_{idx if split > 1 else ''}"
else:
storage_path = f"{args.output_dir}/{args.prefix}_fbank_{partition}_{idx if split > 1 else ''}"
if args.prefix.lower() in [
"ljspeech",
"aishell",
"baker",
"wenetspeech4tts",
]:
part = part.resample(24000)
assert args.prefix.lower() in [
"ljspeech",
"aishell",
"baker",
"wenetspeech4tts",
"libritts",
"libritts-r",
]
with torch.no_grad():
if (
torch.cuda.is_available()
and args.audio_extractor == "Encodec"
):
part = part.compute_and_store_features_batch(
extractor=audio_extractor,
storage_path=storage_path,
num_workers=num_jobs,
batch_duration=args.batch_duration,
collate=False,
overwrite=True,
storage_type=NumpyHdf5Writer,
)
else:
part = part.compute_and_store_features(
extractor=audio_extractor,
storage_path=storage_path,
num_jobs=num_jobs if ex is None else 64,
executor=ex,
storage_type=NumpyHdf5Writer,
)
# TextTokenizer
if args.text_extractor:
for c in tqdm(part):
if (
args.prefix == "baker"
and args.text_extractor == "labeled_pinyin"
):
phonemes = c.supervisions[0].custom["tokens"]["text"]
unique_symbols.update(phonemes)
else:
if args.prefix == "ljspeech":
text = c.supervisions[0].custom["normalized_text"]
text = text.replace("“", '"').replace("”", '"')
phonemes = tokenize_text(text_tokenizer, text=text)
elif args.prefix in [
"aishell",
"aishell2",
"wenetspeech4tts",
"libritts",
"libritts-r",
]:
phonemes = tokenize_text(
text_tokenizer, text=c.supervisions[0].text
)
if c.supervisions[0].custom is None:
c.supervisions[0].custom = {}
c.supervisions[0].normalized_text = c.supervisions[
0
].text
else:
raise NotImplementedError(f"{args.prefix}")
c.supervisions[0].custom["tokens"] = {"text": phonemes}
unique_symbols.update(phonemes)
c.tokens = phonemes
assert c.supervisions[
0
].normalized_text, "normalized_text is None"
# Save each part with an index if split > 1
cuts_filename = (
f"{prefix}cuts_{partition}.{idx if split > 1 else ''}.{args.suffix}"
)
part.to_file(f"{args.output_dir}/{cuts_filename}")
logging.info(f"Saved {cuts_filename}")
if args.text_extractor:
unique_phonemes = SymbolTable()
for s in sorted(list(unique_symbols)):
unique_phonemes.add(s)
logging.info(f"{len(unique_symbols)} unique phonemes: {unique_symbols}")
unique_phonemes_file = f"{args.output_dir}/unique_text_tokens.k2symbols"
unique_phonemes.to_file(unique_phonemes_file)
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@@ -0,0 +1,53 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
# Copyright 2023 (authors: Feiteng Li)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This file displays duration statistics of utterances in the manifests.
You can use the displayed value to choose minimum/maximum duration
to remove short and long utterances during the training.
"""
import argparse
from pathlib import Path
from lhotse import load_manifest_lazy
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--manifest-dir",
type=Path,
default=Path("data/tokenized"),
help="Path to the tokenized manifests.",
)
return parser.parse_args()
def main():
args = get_args()
manifest_dir = args.manifest_dir or Path("data/tokenized")
for part in ["train", "dev", "test"]:
print(f"## {part}")
cuts = load_manifest_lazy(manifest_dir / f"cuts_{part}.jsonl.gz")
cuts.describe()
print("\n")
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,101 @@
#!/usr/bin/env bash
set -eou pipefail
# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674
export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
j=16
stage=2
stop_stage=2
dl_dir=$PWD/download
dataset_parts="-p Basic" # -p Premium for Premium dataset only
text_extractor="pypinyin_initials_finals" # default is espeak for English
audio_extractor="Encodec" # or Fbank
audio_feats_dir=data/tokenized
. shared/parse_options.sh || exit 1
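# parse_options.sh (sourced above, kaldi-style) allows any variable defined
# before it to be overridden from the command line,
# e.g. ./prepare.sh --stage 0 --stop_stage 1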
# All files generated by this script are saved in "data".
# You can safely remove "data" and rerun this script to regenerate it.
mkdir -p data
log() {
# This function is from espnet
local fname=${BASH_SOURCE[1]##*/}
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}
if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
log "dl_dir: $dl_dir"
log "Stage 0: Download data"
huggingface-cli login
huggingface-cli download --repo-type dataset --local-dir $dl_dir Wenetspeech4TTS/WenetSpeech4TTS
# Extract the downloaded data:
for folder in Standard Premium Basic; do
for file in "$dl_dir/$folder"/*.tar.gz; do
tar -xzvf "$file" -C "$dl_dir/$folder"
done
done
fi
if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
log "Stage 1: Prepare wenetspeech4tts manifest"
# We assume that you have downloaded the wenetspeech4tts corpus
# to $dl_dir/wenetspeech4tts
mkdir -p data/manifests
if [ ! -e data/manifests/.wenetspeech4tts.done ]; then
lhotse prepare wenetspeech4tts $dl_dir data/manifests --dataset-parts "${dataset_parts}"
touch data/manifests/.wenetspeech4tts.done
fi
fi
if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
log "Stage 2: Tokenize/Fbank wenetspeech4tts"
mkdir -p ${audio_feats_dir}
if [ ! -e ${audio_feats_dir}/.wenetspeech4tts.tokenize.done ]; then
python3 ./local/compute_neural_codec_and_prepare_text_tokens.py --dataset-parts "${dataset_parts}" \
--text-extractor ${text_extractor} \
--audio-extractor ${audio_extractor} \
--batch-duration 2500 \
--prefix "wenetspeech4tts" \
--src-dir "data/manifests" \
--split 100 \
--output-dir "${audio_feats_dir}/${prefix}_baisc_split_100"
fi
touch ${audio_feats_dir}/.wenetspeech4tts.tokenize.done
fi
if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
log "Stage 13: Combine features for basic"
if [ ! -f ${audio_feats_dir}/wenetspeech4tts_cuts_Basic.jsonl.gz ]; then
pieces=$(find ${audio_feats_dir}/wenetspeech4tts_basic_split_100 -name "*.jsonl.gz")
lhotse combine $pieces ${audio_feats_dir}/wenetspeech4tts_cuts_Basic.jsonl.gz
fi
fi
if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
log "Stage 3: Prepare wenetspeech4tts train/dev/test"
if [ ! -e ${audio_feats_dir}/.wenetspeech4tts.train.done ]; then
lhotse subset --first 400 \
${audio_feats_dir}/wenetspeech4tts_cuts_Basic.jsonl.gz \
${audio_feats_dir}/cuts_dev.jsonl.gz
lhotse subset --last 400 \
${audio_feats_dir}/wenetspeech4tts_cuts_Basic.jsonl.gz \
${audio_feats_dir}/cuts_test.jsonl.gz
lhotse copy \
${audio_feats_dir}/wenetspeech4tts_cuts_Basic.jsonl.gz \
${audio_feats_dir}/cuts_train.jsonl.gz
touch ${audio_feats_dir}/.wenetspeech4tts.train.done
fi
python3 ./local/display_manifest_statistics.py --manifest-dir ${audio_feats_dir}
fi

View File

@@ -0,0 +1 @@
../../../icefall/shared/

View File

@@ -0,0 +1 @@
../local/compute_neural_codec_and_prepare_text_tokens.py

View File

@@ -40,16 +40,17 @@ os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"
import torch
import torchaudio
-from icefall.utils import AttributeDict, str2bool
-from valle.data import (
+from compute_neural_codec_and_prepare_text_tokens import (
AudioTokenizer,
TextTokenizer,
tokenize_audio,
tokenize_text,
)
-from valle.data.collation import get_text_token_collater
-from valle.models import get_model
+from k2 import symbol_table
+from tokenizer import get_text_token_collater
+from valle import VALLE
+from icefall.utils import AttributeDict, str2bool
def get_args():
@@ -70,21 +71,12 @@ def get_args():
)
parser.add_argument(
-"--text",
+"--manifest",
type=str,
-default="To get up and running quickly just follow the steps below.",
-help="Text to be synthesized.",
+default="",
+help="prompt text\t prompt audio\ttarget text\ttarget audio",
)
-# model
-# add_model_arguments(parser)
-# parser.add_argument(
-#     "--text-tokens",
-#     type=str,
-#     default="data/tokenized/unique_text_tokens.k2symbols",
-#     help="Path to the unique text tokens file.",
-# )
parser.add_argument(
"--text-extractor",
type=str,
@@ -143,8 +135,19 @@ def load_model(checkpoint, device):
checkpoint = torch.load(checkpoint, map_location=device)
-args = AttributeDict(checkpoint)
-model = get_model(args)
+params = AttributeDict(checkpoint)
+model = VALLE(
+params.decoder_dim,
+params.nhead,
+params.num_decoder_layers,
+norm_first=params.norm_first,
+add_prenet=params.add_prenet,
+prefix_mode=params.prefix_mode,
+share_embedding=params.share_embedding,
+nar_scale_factor=params.scale_factor,
+prepend_bos=params.prepend_bos,
+num_quantizers=params.num_quantizers,
+)
missing_keys, unexpected_keys = model.load_state_dict(
checkpoint["model"], strict=True
@@ -153,9 +156,7 @@ def load_model(checkpoint, device):
model.to(device)
model.eval()
-text_tokens = args.text_tokens
-return model, text_tokens
+return model, params.text_tokens
@torch.no_grad()
@@ -181,9 +182,7 @@ def main():
encoded_frames = tokenize_audio(audio_tokenizer, audio_file)
if False:
samples = audio_tokenizer.decode(encoded_frames)
-torchaudio.save(
-f"{args.output_dir}/p{n}.wav", samples[0], 24000
-)
+torchaudio.save(f"{args.output_dir}/p{n}.wav", samples[0], 24000)
audio_prompts.append(encoded_frames[0][0])
@@ -195,8 +194,8 @@ def main():
# https://github.com/lifeiteng/lifeiteng.github.com/blob/main/valle/prepare.py
with open(args.text) as f:
for line in f:
-# fields = line.strip().split("\t")
-fields = line.strip().split(" ")
+fields = line.strip().split("\t")
+# fields = line.strip().split(" ")
fields = [item for item in fields if item]
assert len(fields) == 4
prompt_text, prompt_audio, text, audio_path = fields
@@ -209,11 +208,7 @@ def main():
]
)
_, enroll_x_lens = text_collater(
-[
-tokenize_text(
-text_tokenizer, text=f"{prompt_text}".strip()
-)
-]
+[tokenize_text(text_tokenizer, text=f"{prompt_text}".strip())]
)
audio_prompts = tokenize_audio(audio_tokenizer, prompt_audio)
@@ -244,11 +239,7 @@ def main():
for n, text in enumerate(args.text.split("|")):
logging.info(f"synthesize text: {text}")
text_tokens, text_tokens_lens = text_collater(
-[
-tokenize_text(
-text_tokenizer, text=f"{text_prompts} {text}".strip()
-)
-]
+[tokenize_text(text_tokenizer, text=f"{text_prompts} {text}".strip())]
)
# synthesis
@@ -263,11 +254,7 @@ def main():
enroll_x_lens = None
if text_prompts:
_, enroll_x_lens = text_collater(
-[
-tokenize_text(
-text_tokenizer, text=f"{text_prompts}".strip()
-)
-]
+[tokenize_text(text_tokenizer, text=f"{text_prompts}".strip())]
)
encoded_frames = model.inference(
text_tokens.to(device),
@@ -280,13 +267,9 @@ def main():
)
if audio_prompts != []:
-samples = audio_tokenizer.decode(
-[(encoded_frames.transpose(2, 1), None)]
-)
+samples = audio_tokenizer.decode([(encoded_frames.transpose(2, 1), None)])
# store
-torchaudio.save(
-f"{args.output_dir}/{n}.wav", samples[0].cpu(), 24000
-)
+torchaudio.save(f"{args.output_dir}/{n}.wav", samples[0].cpu(), 24000)
else:  # Transformer
pass
@@ -297,8 +280,6 @@ torch._C._jit_set_profiling_executor(False)
torch._C._jit_set_profiling_mode(False)
torch._C._set_graph_executor_optimize(False)
if __name__ == "__main__":
-formatter = (
-"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-)
+formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@@ -3,9 +3,9 @@ from typing import List, Tuple
import numpy as np
import torch
from k2 import SymbolTable
class TextTokenCollater:
"""Collate list of text tokens
@@ -52,15 +52,10 @@ class TextTokenCollater:
self.token2idx = {token: idx for idx, token in enumerate(unique_tokens)}
self.idx2token = [token for token in unique_tokens]
-def index(
-self, tokens_list: List[str]
-) -> Tuple[torch.Tensor, torch.Tensor]:
+def index(self, tokens_list: List[str]) -> Tuple[torch.Tensor, torch.Tensor]:
seqs, seq_lens = [], []
for tokens in tokens_list:
-assert (
-all([True if s in self.token2idx else False for s in tokens])
-is True
-)
+assert all([True if s in self.token2idx else False for s in tokens]) is True
seq = (
([self.bos_symbol] if self.add_bos else [])
+ list(tokens)
@@ -103,10 +98,7 @@ class TextTokenCollater:
)
tokens_lens = torch.IntTensor(
-[
-len(seq) + int(self.add_eos) + int(self.add_bos)
-for seq in tokens_seqs
-]
+[len(seq) + int(self.add_eos) + int(self.add_bos) for seq in tokens_seqs]
)
return tokens_batch, tokens_lens
@@ -115,7 +107,5 @@ class TextTokenCollater:
def get_text_token_collater(text_tokens_file: str) -> TextTokenCollater:
text_tokens_path = Path(text_tokens_file)
unique_tokens = SymbolTable.from_file(text_tokens_path)
-collater = TextTokenCollater(
-unique_tokens.symbols, add_bos=True, add_eos=True
-)
+collater = TextTokenCollater(unique_tokens.symbols, add_bos=True, add_eos=True)
return collater

View File

@ -49,10 +49,9 @@ import argparse
import copy import copy
import logging import logging
import os import os
from contextlib import nullcontext
import random import random
import warnings import warnings
from contextlib import nullcontext
from pathlib import Path from pathlib import Path
from shutil import copyfile from shutil import copyfile
from typing import Any, Dict, Optional, Tuple, Union from typing import Any, Dict, Optional, Tuple, Union
@ -60,6 +59,19 @@ from typing import Any, Dict, Optional, Tuple, Union
import torch import torch
import torch.multiprocessing as mp import torch.multiprocessing as mp
import torch.nn as nn import torch.nn as nn
from lhotse import CutSet
from lhotse.cut import Cut
from lhotse.dataset.sampling.base import CutSampler
from lhotse.utils import fix_random_seed
from optim import Eden, ScaledAdam
from tokenizer import TextTokenCollater, get_text_token_collater
from torch import Tensor
from torch.cuda.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter
from tts_datamodule import TtsDataModule
from valle import VALLE
from icefall.checkpoint import load_checkpoint, remove_checkpoints from icefall.checkpoint import load_checkpoint, remove_checkpoints
from icefall.checkpoint import save_checkpoint as save_checkpoint_impl from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
from icefall.checkpoint import ( from icefall.checkpoint import (
@ -70,22 +82,10 @@ from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info from icefall.env import get_env_info
from icefall.hooks import register_inf_check_hooks from icefall.hooks import register_inf_check_hooks
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
from lhotse import CutSet
from lhotse.cut import Cut
from lhotse.dataset.sampling.base import CutSampler
from lhotse.utils import fix_random_seed
from torch import Tensor
from torch.cuda.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter
from tts_datamodule import TtsDataModule
from optim import Eden, ScaledAdam
from valle import VALLE
from tokenizer import TextTokenCollater, get_text_token_collater
LRSchedulerType = torch.optim.lr_scheduler._LRScheduler LRSchedulerType = torch.optim.lr_scheduler._LRScheduler
def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None:
if isinstance(model, DDP): if isinstance(model, DDP):
# get underlying nn.Module # get underlying nn.Module
@ -95,6 +95,7 @@ def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None:
if hasattr(module, "batch_count"): if hasattr(module, "batch_count"):
module.batch_count = batch_count module.batch_count = batch_count
def add_model_arguments(parser: argparse.ArgumentParser): def add_model_arguments(parser: argparse.ArgumentParser):
parser.add_argument( parser.add_argument(
"--decoder-dim", "--decoder-dim",
@ -159,6 +160,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
help="Number of Audio/Semantic quantization layers.", help="Number of Audio/Semantic quantization layers.",
) )
def get_parser(): def get_parser():
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter formatter_class=argparse.ArgumentDefaultsHelpFormatter
@ -363,7 +365,6 @@ def get_parser():
return parser return parser
def get_params() -> AttributeDict: def get_params() -> AttributeDict:
"""Return a dict containing training parameters. """Return a dict containing training parameters.
@ -568,9 +569,10 @@ def save_checkpoint(
best_valid_filename = params.exp_dir / "best-valid-loss.pt" best_valid_filename = params.exp_dir / "best-valid-loss.pt"
copyfile(src=filename, dst=best_valid_filename) copyfile(src=filename, dst=best_valid_filename)
-def prepare_input(batch: dict, tokenizer: Tokenizer, device: torch.device):
+def prepare_input(batch: dict, tokenizer: TextTokenCollater, device: torch.device):
"""Parse batch data""" """Parse batch data"""
features = batch["features"].to(device) features = batch["features"].to(device)
features_lens = batch["features_lens"].to(device) features_lens = batch["features_lens"].to(device)
if "tokens" not in batch: if "tokens" not in batch:
@@ -590,6 +592,7 @@ def prepare_input(batch: dict, tokenizer: Tokenizer, device: torch.device):
    return features, features_lens, text_tokens, text_tokens_lens


def compute_loss(
    params: AttributeDict,
    model: Union[nn.Module, DDP],
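(For context, a minimal sketch of how the text-token collater imported above is typically used when preparing a batch; the symbol-table path and the exact TextTokenCollater call signature are assumptions for illustration, not taken from this diff.)

# Hedged sketch: a collater is assumed to map a batch of token sequences to a
# padded LongTensor plus per-utterance lengths, matching what prepare_input returns.
from tokenizer import get_text_token_collater

collater = get_text_token_collater("data/tokenized/unique_text_tokens.k2symbols")  # path is an assumption
text_tokens, text_tokens_lens = collater([["h", "e", "l", "l", "o"]])
print(text_tokens.shape, text_tokens_lens)  # (batch, max_len) and lengths before padding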
@@ -615,11 +618,7 @@ def compute_loss(
      warmup: a floating point value which increases throughout training;
        values >= 1.0 are fully warmed up and have all modules present.
    """
-    device = (
-        model.device
-        if isinstance(model, DDP)
-        else next(model.parameters()).device
-    )
+    device = model.device if isinstance(model, DDP) else next(model.parameters()).device
    (
        audio_features,
        audio_features_lens,
@@ -684,9 +683,7 @@ def compute_validation_loss(
        params.best_valid_loss = loss_value

    if params.visualize:
-        output_dir = Path(
-            f"{params.exp_dir}/eval/step-{params.batch_idx_train:06d}"
-        )
+        output_dir = Path(f"{params.exp_dir}/eval/step-{params.batch_idx_train:06d}")
        output_dir.mkdir(parents=True, exist_ok=True)
        if isinstance(model, DDP):
            model.module.visualize(predicts, batch, output_dir=output_dir)
@@ -777,26 +774,21 @@ def train_one_epoch(
            is_training=True,
        )
        # summary stats
-        tot_loss = (
-            tot_loss * (1 - 1 / params.reset_interval)
-        ) + loss_info * (1 / params.reset_interval)
+        tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info * (
+            1 / params.reset_interval
+        )

        # NOTE: We use reduction==sum and loss is computed over utterances
        # in the batch and there is no normalization to it so far.
        scaler.scale(loss).backward()
        if params.batch_idx_train >= params.accumulate_grad_steps:
-            if (
-                params.batch_idx_train % params.accumulate_grad_steps
-                == 0
-            ):
+            if params.batch_idx_train % params.accumulate_grad_steps == 0:
                if params.optimizer_name not in ["ScaledAdam", "Eve"]:
                    # Unscales the gradients of optimizer's assigned params in-place
                    scaler.unscale_(optimizer)
                    # Since the gradients of optimizer's assigned params are unscaled, clips as usual:
-                    torch.nn.utils.clip_grad_norm_(
-                        model.parameters(), 1.0
-                    )
+                    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                scaler.step(optimizer)
                scaler.update()
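(A self-contained sketch of the accumulate-unscale-clip-step pattern used in the hunk above, assuming a CUDA device is available; the toy model and the accumulate_grad_steps value are placeholders, not values from this recipe.)

import torch

model = torch.nn.Linear(8, 1).cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
scaler = torch.cuda.amp.GradScaler(enabled=True)
accumulate_grad_steps = 4  # placeholder

for step in range(1, 101):
    x = torch.randn(2, 8, device="cuda")
    with torch.cuda.amp.autocast():
        loss = model(x).sum()
    # gradients from several micro-batches accumulate in .grad
    scaler.scale(loss).backward()
    if step % accumulate_grad_steps == 0:
        scaler.unscale_(optimizer)  # so clipping sees true gradient magnitudes
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()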
@@ -825,7 +817,7 @@ def train_one_epoch(
                model_cur=model,
                model_avg=model_avg,
            )
        if (
            params.batch_idx_train > 0
            and params.batch_idx_train % params.save_every_n == 0
@@ -849,15 +841,13 @@ def train_one_epoch(
                topk=params.keep_last_k,
                rank=rank,
            )
        if batch_idx % 100 == 0 and params.dtype in ["float16", "fp16"]:
            # If the grad scale was less than 1, try increasing it. The _growth_interval
            # of the grad scaler is configurable, but we can't configure it to have different
            # behavior depending on the current grad scale.
            cur_grad_scale = scaler._scale.item()
-            if cur_grad_scale < 1.0 or (
-                cur_grad_scale < 8.0 and batch_idx % 400 == 0
-            ):
+            if cur_grad_scale < 1.0 or (cur_grad_scale < 8.0 and batch_idx % 400 == 0):
                scaler.update(cur_grad_scale * 2.0)
            if cur_grad_scale < 0.01:
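(The block above works around GradScaler's fixed growth interval by doubling a collapsed scale manually; the same idea in isolation, with thresholds copied from the hunk. Raising on a tiny scale is this sketch's choice of failure handling, not necessarily the recipe's.)

def maybe_grow_grad_scale(scaler, batch_idx: int) -> None:
    # Double a collapsed or slow-growing loss scale; bail out if training diverged.
    cur_grad_scale = scaler._scale.item()  # private attribute, used as in the recipe
    if cur_grad_scale < 1.0 or (cur_grad_scale < 8.0 and batch_idx % 400 == 0):
        scaler.update(cur_grad_scale * 2.0)
    if cur_grad_scale < 0.01:
        raise RuntimeError(f"grad_scale is too small: {cur_grad_scale}")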
@@ -870,9 +860,7 @@ def train_one_epoch(
        if batch_idx % params.log_interval == 0:
            cur_lr = scheduler.get_last_lr()[0]
-            cur_grad_scale = (
-                scaler._scale.item()
-                if params.dtype in ["float16", "fp16"]
-                else 1.0
-            )
+            cur_grad_scale = (
+                scaler._scale.item() if params.dtype in ["float16", "fp16"] else 1.0
+            )
            logging.info(
@@ -897,12 +885,8 @@ def train_one_epoch(
                    "train/current_",
                    params.batch_idx_train,
                )
-                tot_loss.write_summary(
-                    tb_writer, "train/tot_", params.batch_idx_train
-                )
-                tot_loss.write_summary(
-                    tb_writer, "train/tot_", params.batch_idx_train
-                )
+                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
+                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
                if params.dtype in ["float16", "fp16"]:
                    tb_writer.add_scalar(
                        "train/grad_scale",
@@ -922,9 +906,7 @@ def train_one_epoch(
                valid_dl=valid_dl,
                world_size=world_size,
            )
-            logging.info(
-                f"Epoch {params.cur_epoch}, validation: {valid_info}"
-            )
+            logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
            logging.info(
                f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
            )
@@ -1063,10 +1045,7 @@ def run(rank, world_size, args):
            )
        else:
            parameters_names.append(
-                [
-                    name_param_pair[0]
-                    for name_param_pair in model.named_parameters()
-                ]
+                [name_param_pair[0] for name_param_pair in model.named_parameters()]
            )
        optimizer = ScaledAdam(
@@ -1144,9 +1123,7 @@ def run(rank, world_size, args):
        params=params,
    )

-    scaler = GradScaler(
-        enabled=(params.dtype in ["fp16", "float16"]), init_scale=1.0
-    )
+    scaler = GradScaler(enabled=(params.dtype in ["fp16", "float16"]), init_scale=1.0)
    if checkpoints and "grad_scaler" in checkpoints:
        logging.info("Loading grad scaler state dict")
        scaler.load_state_dict(checkpoints["grad_scaler"])
View File
@@ -52,6 +52,7 @@ class _SeedWorkers:
    def __call__(self, worker_id: int):
        fix_random_seed(self.seed + worker_id)


class TtsDataModule:
    """
    DataModule for tts experiments.
@@ -301,9 +302,7 @@ class TtsDataModule:
    @lru_cache()
    def train_cuts(self) -> CutSet:
        logging.info("About to get train cuts")
-        return load_manifest_lazy(
-            self.args.manifest_dir / "cuts_train.jsonl.gz"
-        )
+        return load_manifest_lazy(self.args.manifest_dir / "cuts_train.jsonl.gz")

    @lru_cache()
    def dev_cuts(self) -> CutSet:
@@ -341,4 +340,4 @@ class TtsDataModule:
        logging.info("About to get test-other cuts")
        return load_manifest_lazy(
            self.args.manifest_dir / "libritts_cuts_test-other.jsonl.gz"
        )
View File
@@ -12,36 +12,36 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import copy
+import math
+import numbers
import random
-from typing import Dict, Iterator, List, Tuple, Union
+from functools import partial
+from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, Union

import torch
import torch.nn as nn
-import torch.nn.functional as F
-from icefall.utils import make_pad_mask
+from torch import Tensor
+from torch.nn import Linear, Module
+from torch.nn import functional as F
+from torch.nn.init import constant_, xavier_normal_, xavier_uniform_
+from torch.nn.modules.linear import NonDynamicallyQuantizableLinear
+from torch.nn.parameter import Parameter
from torchmetrics.classification import MulticlassAccuracy
-# from valle.data.input_strategies import PromptedFeatures
-# from valle.modules.embedding import SinePositionalEmbedding, TokenEmbedding
-# from valle.modules.transformer import (
-#     AdaptiveLayerNorm,
-#     LayerNorm,
-#     TransformerEncoder,
-#     TransformerEncoderLayer,
-# )
+from icefall.utils import make_pad_mask

from .macros import NUM_AUDIO_TOKENS, NUM_TEXT_TOKENS
from .visualizer import visualize
class PromptedFeatures:
    def __init__(self, prompts, features):
        self.prompts = prompts
        self.features = features

    def to(self, device):
-        return PromptedFeatures(
-            self.prompts.to(device), self.features.to(device)
-        )
+        return PromptedFeatures(self.prompts.to(device), self.features.to(device))

    def sum(self):
        return self.features.sum()
@@ -54,6 +54,7 @@ class PromptedFeatures:
    def data(self):
        return (self.prompts, self.features)


class TokenEmbedding(nn.Module):
    def __init__(
        self,
@@ -114,9 +115,7 @@ class SinePositionalEmbedding(nn.Module):
                x.size(1) - 1, -1, -1.0, dtype=torch.float32
            ).unsqueeze(1)
        else:
-            position = torch.arange(
-                0, x.size(1), dtype=torch.float32
-            ).unsqueeze(1)
+            position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, self.dim_model, 2, dtype=torch.float32)
            * -(math.log(10000.0) / self.dim_model)
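(The position/div_term lines above are the standard sinusoidal table; a compact standalone version for reference, with illustrative shapes:)

import math
import torch

def sine_positional_encoding(length: int, dim_model: int) -> torch.Tensor:
    # position: (T, 1); div_term: (D/2,), the geometric frequency ladder
    position = torch.arange(0, length, dtype=torch.float32).unsqueeze(1)
    div_term = torch.exp(
        torch.arange(0, dim_model, 2, dtype=torch.float32)
        * -(math.log(10000.0) / dim_model)
    )
    pe = torch.zeros(length, dim_model)
    pe[:, 0::2] = torch.sin(position * div_term)  # even dims
    pe[:, 1::2] = torch.cos(position * div_term)  # odd dims
    return pe.unsqueeze(0)  # (1, T, D), broadcasts over the batch

pe = sine_positional_encoding(length=100, dim_model=512)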
@@ -132,14 +131,17 @@
        output = output * self.x_scale + self.alpha * self.pe[:, : x.size(1)]
        return self.dropout(output)
class Transpose(nn.Identity):
    """(N, T, D) -> (N, D, T)"""

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return input.transpose(1, 2)
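(Why Transpose exists: nn.Conv1d expects (N, C, T) while the rest of the model works in (N, T, D), and wrapping the swap in a module lets it sit inside nn.Sequential, as the prenets in this file do. An illustrative usage, with placeholder sizes:)

import torch
import torch.nn as nn

prenet = nn.Sequential(
    Transpose(),  # (N, T, D) -> (N, D, T) so Conv1d sees channels first
    nn.Conv1d(512, 512, kernel_size=5, padding="same"),
    Transpose(),  # back to (N, T, D)
)
out = prenet(torch.randn(2, 100, 512))
assert out.shape == (2, 100, 512)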
_shape_t = Union[int, List[int], torch.Size]


class MultiheadAttention(Module):
    r"""Allows the model to jointly attend to information
    from different representation subspaces as described in the paper:
@@ -221,9 +223,7 @@ class MultiheadAttention(Module):
        self.embed_dim = embed_dim
        self.kdim = kdim if kdim is not None else embed_dim
        self.vdim = vdim if vdim is not None else embed_dim
-        self._qkv_same_embed_dim = (
-            self.kdim == embed_dim and self.vdim == embed_dim
-        )
+        self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim

        self.num_heads = num_heads
        self.dropout = dropout
@@ -234,12 +234,8 @@
        ), "embed_dim must be divisible by num_heads"

        if add_bias_kv:
-            self.bias_k = Parameter(
-                torch.empty((1, 1, embed_dim), **factory_kwargs)
-            )
-            self.bias_v = Parameter(
-                torch.empty((1, 1, embed_dim), **factory_kwargs)
-            )
+            self.bias_k = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs))
+            self.bias_v = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs))
        else:
            self.bias_k = self.bias_v = None
@@ -396,20 +392,18 @@ class MultiheadAttention(Module):
            )

        why_not_fast_path = ""
        if not is_batched:
-            why_not_fast_path = f"input not batched; expected query.dim() of 3 but got {query.dim()}"
+            why_not_fast_path = (
+                f"input not batched; expected query.dim() of 3 but got {query.dim()}"
+            )
        elif query is not key or key is not value:
            # When lifting this restriction, don't forget to either
            # enforce that the dtypes all match or test cases where
            # they don't!
            why_not_fast_path = "non-self attention was used (query, key, and value are not the same Tensor)"
-        elif (
-            self.in_proj_bias is not None
-            and query.dtype != self.in_proj_bias.dtype
-        ):
+        elif self.in_proj_bias is not None and query.dtype != self.in_proj_bias.dtype:
            why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_bias ({self.in_proj_bias.dtype}) don't match"
        elif (
-            self.in_proj_weight is not None
-            and query.dtype != self.in_proj_weight.dtype
+            self.in_proj_weight is not None and query.dtype != self.in_proj_weight.dtype
        ):
            # this case will fail anyway, but at least they'll get a useful error message.
            why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_weight ({self.in_proj_weight.dtype}) don't match"
@@ -458,9 +452,7 @@ class MultiheadAttention(Module):
                    for x in tensor_args
                ]
            ):
-                why_not_fast_path = (
-                    "some Tensor argument is neither CUDA nor CPU"
-                )
+                why_not_fast_path = "some Tensor argument is neither CUDA nor CPU"
            elif torch.is_grad_enabled() and any(
                [x is not None and x.requires_grad for x in tensor_args]
            ):
@@ -479,9 +471,7 @@ class MultiheadAttention(Module):
                    self.in_proj_bias,
                    self.out_proj.weight,
                    self.out_proj.bias,
-                    key_padding_mask
-                    if key_padding_mask is not None
-                    else attn_mask,
+                    key_padding_mask if key_padding_mask is not None else attn_mask,
                    need_weights,
                    average_attn_weights,
                    1
@@ -506,9 +496,7 @@ class MultiheadAttention(Module):
                query, key = [x.transpose(1, 0) for x in (query, key)]
                value = key
            else:
-                query, key, value = [
-                    x.transpose(1, 0) for x in (query, key, value)
-                ]
+                query, key, value = [x.transpose(1, 0) for x in (query, key, value)]

        if not self._qkv_same_embed_dim:
            attn_output, attn_output_weights = F.multi_head_attention_forward(
@@ -722,14 +710,11 @@ class TransformerEncoderLayer(nn.Module):
        self.activation = activation

        norm1 = layer_norm_cls(d_model, eps=layer_norm_eps, **factory_kwargs)
-        if layer_norm_cls == IdentityNorm:
-            norm2 = BalancedBasicNorm(
-                d_model, eps=layer_norm_eps, **factory_kwargs
-            )
-        else:
-            norm2 = layer_norm_cls(
-                d_model, eps=layer_norm_eps, **factory_kwargs
-            )
+        # if layer_norm_cls == IdentityNorm:
+        #     norm2 = BalancedBasicNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
+        # else:
+        if True:
+            norm2 = layer_norm_cls(d_model, eps=layer_norm_eps, **factory_kwargs)

        if adaptive_layer_norm:
            self.norm1 = AdaptiveLayerNorm(d_model, norm1)
@@ -887,7 +872,6 @@ class TransformerEncoder(nn.Module):
        return output


def _get_clones(module, N):
    return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
@@ -898,9 +882,8 @@ def _get_activation_fn(activation: str) -> Callable[[Tensor], Tensor]:
    elif activation == "gelu":
        return F.gelu

-    raise RuntimeError(
-        "activation should be relu/gelu, not {}".format(activation)
-    )
+    raise RuntimeError("activation should be relu/gelu, not {}".format(activation))


class VALLE(nn.Module):
    """It implements https://arxiv.org/abs/2301.02111
@@ -1003,9 +986,7 @@ class VALLE(nn.Module):
            num_layers=num_layers,
            norm=LayerNorm(d_model) if norm_first else None,
        )
-        self.ar_predict_layer = nn.Linear(
-            d_model, NUM_AUDIO_TOKENS + 1, bias=False
-        )
+        self.ar_predict_layer = nn.Linear(d_model, NUM_AUDIO_TOKENS + 1, bias=False)

        self.ar_accuracy_metric = MulticlassAccuracy(
            NUM_AUDIO_TOKENS + 1,
@@ -1034,21 +1015,15 @@ class VALLE(nn.Module):
        if add_prenet:
            self.nar_text_prenet = nn.Sequential(
                Transpose(),
-                nn.Conv1d(
-                    nar_d_model, nar_d_model, kernel_size=5, padding="same"
-                ),
+                nn.Conv1d(nar_d_model, nar_d_model, kernel_size=5, padding="same"),
                nn.BatchNorm1d(nar_d_model),
                nn.ReLU(),
                nn.Dropout(0.5),
-                nn.Conv1d(
-                    nar_d_model, nar_d_model, kernel_size=5, padding="same"
-                ),
+                nn.Conv1d(nar_d_model, nar_d_model, kernel_size=5, padding="same"),
                nn.BatchNorm1d(nar_d_model),
                nn.ReLU(),
                nn.Dropout(0.5),
-                nn.Conv1d(
-                    nar_d_model, nar_d_model, kernel_size=5, padding="same"
-                ),
+                nn.Conv1d(nar_d_model, nar_d_model, kernel_size=5, padding="same"),
                nn.BatchNorm1d(nar_d_model),
                nn.ReLU(),
                nn.Dropout(0.5),
@@ -1092,9 +1067,7 @@ class VALLE(nn.Module):
                    adaptive_layer_norm=True,
                ),
                num_layers=int(num_layers * nar_scale_factor),
-                norm=AdaptiveLayerNorm(
-                    nar_d_model, norm=nn.LayerNorm(nar_d_model)
-                )
+                norm=AdaptiveLayerNorm(nar_d_model, norm=nn.LayerNorm(nar_d_model))
                if norm_first
                else None,
            )
@@ -1105,10 +1078,7 @@ class VALLE(nn.Module):
            ]
        )
        self.nar_stage_embeddings = nn.ModuleList(
-            [
-                TokenEmbedding(nar_d_model, 1)
-                for i in range(num_quantizers - 1)
-            ]
+            [TokenEmbedding(nar_d_model, 1) for i in range(num_quantizers - 1)]
        )
        if share_embedding:
@@ -1119,9 +1089,9 @@ class VALLE(nn.Module):
            # We also share the parameters of the acoustic embedding layer and the output prediction layer,
            # which means the weights of the j-th prediction layer are the same as the (j + 1)-th acoustic embedding layer.
            for j in range(0, num_quantizers - 2):
-                self.nar_predict_layers[
-                    j
-                ].weight = self.nar_audio_embeddings[j + 2].weight
+                self.nar_predict_layers[j].weight = self.nar_audio_embeddings[
+                    j + 2
+                ].weight

        self.nar_accuracy_metric = MulticlassAccuracy(
            NUM_AUDIO_TOKENS + 1,
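(The loop above ties each NAR prediction layer's weight matrix to the acoustic embedding two stages ahead, so the layer that scores codebook j+1 shares parameters with the embedding that will consume those tokens. A toy illustration of the same tying; the sizes are placeholders, and plain nn.Embedding stands in for TokenEmbedding:)

import torch.nn as nn

emb = nn.Embedding(1025, 512)            # stand-in for nar_audio_embeddings[j + 2]
pred = nn.Linear(512, 1025, bias=False)  # stand-in for nar_predict_layers[j]
pred.weight = emb.weight                 # one Parameter shared by two modules
assert pred.weight.data_ptr() == emb.weight.data_ptr()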
@@ -1192,13 +1162,9 @@ class VALLE(nn.Module):
            y_prompts = self.nar_audio_embeddings[0](y[:, :prefix_len])
            y_emb = self.nar_audio_embeddings[0](y[:, prefix_len:])
            for j in range(1, self.num_quantizers):
-                y_prompts += self.nar_audio_embeddings[j](
-                    codes[:, :prefix_len, j]
-                )
+                y_prompts += self.nar_audio_embeddings[j](codes[:, :prefix_len, j])
                if j < nar_stage:
-                    y_emb += self.nar_audio_embeddings[j](
-                        codes[:, prefix_len:, j]
-                    )
+                    y_emb += self.nar_audio_embeddings[j](codes[:, prefix_len:, j])
            y_emb = torch.concat([y_prompts, y_emb], axis=1)
        elif self.prefix_mode in [2, 4]:
            if self.prefix_mode == 2:
@@ -1211,9 +1177,7 @@ class VALLE(nn.Module):
                    y_prompts_codes.append(
                        torch.clone(codes[b, start : start + prefix_len])
                    )
-                    codes[
-                        b, start : start + prefix_len, nar_stage
-                    ] = NUM_AUDIO_TOKENS
+                    codes[b, start : start + prefix_len, nar_stage] = NUM_AUDIO_TOKENS
                y_prompts_codes = torch.stack(y_prompts_codes, dim=0)
            else:
                prefix_len = y_prompts_codes.shape[1]
@@ -1221,9 +1185,7 @@ class VALLE(nn.Module):
            y_prompts = self.nar_audio_embeddings[0](y_prompts_codes[..., 0])
            y_emb = self.nar_audio_embeddings[0](y)
            for j in range(1, self.num_quantizers):
-                y_prompts += self.nar_audio_embeddings[j](
-                    y_prompts_codes[..., j]
-                )
+                y_prompts += self.nar_audio_embeddings[j](y_prompts_codes[..., j])
                if j < nar_stage:
                    y_emb += self.nar_audio_embeddings[j](codes[..., j])
            y_emb = torch.concat([y_prompts, y_emb], axis=1)
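(The pattern above sums the embeddings of the codebooks revealed so far: the prompt side sums all quantizers, while the target side only sums stages below the one being predicted. A toy version with illustrative sizes:)

import torch
import torch.nn as nn

num_quantizers, vocab, dim = 8, 1026, 512
embeddings = nn.ModuleList(nn.Embedding(vocab, dim) for _ in range(num_quantizers))
codes = torch.randint(0, 1024, (2, 50, num_quantizers))  # (B, T, Q)

nar_stage = 3  # stage currently being predicted (illustrative)
y_emb = embeddings[0](codes[..., 0])
for j in range(1, num_quantizers):
    if j < nar_stage:
        y_emb += embeddings[j](codes[..., j])  # only already-generated stages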
@@ -1290,9 +1252,7 @@ class VALLE(nn.Module):
        text = x
        codes = y.type(torch.int64) * (1 - y_mask_int.unsqueeze(dim=-1))
-        y, targets = self.pad_y_eos(
-            codes[..., 0], y_mask_int, eos_id=NUM_AUDIO_TOKENS
-        )
+        y, targets = self.pad_y_eos(codes[..., 0], y_mask_int, eos_id=NUM_AUDIO_TOKENS)

        x_len = x_lens.max()
@@ -1408,21 +1368,16 @@ class VALLE(nn.Module):
            xy_dec = xy_dec[:, x_lens.max() + prefix_len :]
            if self.prefix_mode == 4:
                prefix_len = 0  # reset for Top10Accuracy metric
-            logits = self.nar_predict_layers[nar_stage - 1](xy_dec).permute(
-                0, 2, 1
-            )
+            logits = self.nar_predict_layers[nar_stage - 1](xy_dec).permute(0, 2, 1)
            # loss
            total_length = (y_lens).sum().type(torch.float32)
-            total_loss += (
-                F.cross_entropy(
-                    logits,
-                    targets,
-                    ignore_index=NUM_AUDIO_TOKENS,
-                    reduction=reduction,
-                )
-                * (total_length / (total_length - prefix_len * x.shape[0]))
-            )
+            total_loss += F.cross_entropy(
+                logits,
+                targets,
+                ignore_index=NUM_AUDIO_TOKENS,
+                reduction=reduction,
+            ) * (total_length / (total_length - prefix_len * x.shape[0]))
            metrics["NarTop10Accuracy"] = (
                self.nar_accuracy_metric(
                    F.pad(
@@ -1505,24 +1460,27 @@ class VALLE(nn.Module):
                value=True,
            )
            y_attn_mask = F.pad(
-                torch.triu(
-                    torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1
-                ),
+                torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1),
                (x_len, 0),
                value=False,
            )
-            xy_attn_mask = torch.concat(
-                [x_attn_mask_pad, y_attn_mask], dim=0
-            ).to(y.device)
+            xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0).to(
+                y.device
+            )

            xy_dec, _ = self.ar_decoder(
                (xy_pos, None),
                mask=xy_attn_mask,
            )
            logits = self.ar_predict_layer(xy_dec[:, -1])
-            ras=True
+            ras = True
            samples = topk_sampling(
-                logits, top_k=top_k, top_p=top_p, temperature=temperature, repetition_aware_sampling=ras, preceding_tokens=y
+                logits,
+                top_k=top_k,
+                top_p=top_p,
+                temperature=temperature,
+                repetition_aware_sampling=ras,
+                preceding_tokens=y,
            )

            if (
@@ -1531,9 +1489,7 @@ class VALLE(nn.Module):
                or (y.shape[1] - prompts.shape[1]) > x_lens.max() * 16
            ):
                if prompts.shape[1] == y.shape[1]:
-                    raise SyntaxError(
-                        "well trained model shouldn't reach here."
-                    )
+                    raise SyntaxError("well trained model shouldn't reach here.")
                print(f"VALL-E EOS [{prompts.shape[1]} -> {y.shape[1]}]")
                break
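(The AR masking built a few hunks above lets every position attend to the whole text prefix while audio positions attend causally to earlier audio; a standalone sketch of that combined mask, with small illustrative lengths. True marks a blocked position:)

import torch
import torch.nn.functional as F

x_len, y_len = 5, 3  # text length, audio length (illustrative)
# text rows: attend to all text, never to audio
x_attn_mask_pad = F.pad(
    torch.zeros(x_len, x_len, dtype=torch.bool), (0, y_len), value=True
)
# audio rows: attend to all text plus a causal triangle over audio
y_attn_mask = F.pad(
    torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1),
    (x_len, 0),
    value=False,
)
xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0)
assert xy_attn_mask.shape == (x_len + y_len, x_len + y_len)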
@@ -1545,9 +1501,7 @@ class VALLE(nn.Module):
            return torch.stack(codes, dim=-1)

        # Non-AR Decoders
-        y_emb = self.nar_audio_embeddings[0](
-            y[:, int(self.ar_audio_prepend_bos) :]
-        )
+        y_emb = self.nar_audio_embeddings[0](y[:, int(self.ar_audio_prepend_bos) :])

        if self.prefix_mode in [2, 4]:  # Exclude enrolled_phonemes
            enrolled_len = enroll_x_lens.max().item()
@@ -1586,15 +1540,11 @@ class VALLE(nn.Module):
                codes.append(samples)

                if i < self.num_quantizers - 2:
-                    y_emb[:, :prefix_len] += embedding_layer(
-                        prompts[..., i + 1]
-                    )
+                    y_emb[:, :prefix_len] += embedding_layer(prompts[..., i + 1])
                    y_emb[:, prefix_len:] += embedding_layer(samples)
        else:
            for j in range(1, self.num_quantizers):
-                y_emb[:, :prefix_len] += self.nar_audio_embeddings[j](
-                    prompts[..., j]
-                )
+                y_emb[:, :prefix_len] += self.nar_audio_embeddings[j](prompts[..., j])

            for i, (predict_layer, embedding_layer) in enumerate(
                zip(
@@ -1687,15 +1637,11 @@ class VALLE(nn.Module):
                codes.append(samples)

                if i < 6:
-                    y_emb[:, :prefix_len] += embedding_layer(
-                        prompts[..., i + 1]
-                    )
+                    y_emb[:, :prefix_len] += embedding_layer(prompts[..., i + 1])
                    y_emb[:, prefix_len:] += embedding_layer(samples)
        else:
            for j in range(1, 8):
-                y_emb[:, :prefix_len] += self.nar_audio_embeddings[j](
-                    prompts[..., j]
-                )
+                y_emb[:, :prefix_len] += self.nar_audio_embeddings[j](prompts[..., j])

            for i, (predict_layer, embedding_layer) in enumerate(
                zip(
@@ -1736,18 +1682,14 @@ def top_k_top_p_filtering(
    From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
    """
    if top_k > 0:
-        top_k = min(
-            max(top_k, min_tokens_to_keep), logits.size(-1)
-        )  # Safety check
+        top_k = min(max(top_k, min_tokens_to_keep), logits.size(-1))  # Safety check
        # Remove all tokens with a probability less than the last token of the top-k
        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
        logits[indices_to_remove] = filter_value

    if top_p < 1.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
-        cumulative_probs = torch.cumsum(
-            F.softmax(sorted_logits, dim=-1), dim=-1
-        )
+        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

        # Remove tokens with cumulative probability above the threshold (token with 0 are kept)
        sorted_indices_to_remove = cumulative_probs > top_p
@@ -1755,9 +1697,7 @@ def top_k_top_p_filtering(
        # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
        sorted_indices_to_remove[..., :min_tokens_to_keep] = 0
        # Shift the indices to the right to keep also the first token above the threshold
-        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[
-            ..., :-1
-        ].clone()
+        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0
        # scatter sorted tensors to original indexing
@@ -1768,7 +1708,14 @@ def top_k_top_p_filtering(
    return logits
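(A usage sketch for the filter above: logits outside the top-k/top-p set are driven to -inf so softmax assigns them zero probability. The vocabulary size and thresholds here are illustrative:)

import torch
import torch.nn.functional as F

logits = torch.randn(1, 1025)  # (batch, vocab), illustrative
filtered = top_k_top_p_filtering(logits.clone(), top_k=10, top_p=0.9)
probs = F.softmax(filtered, dim=-1)
token = torch.multinomial(probs, num_samples=1)
assert (probs > 0).sum() <= 10  # at most top_k candidates survive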
-def topk_sampling(logits, top_k=10, top_p=1.0, temperature=1.0, repetition_aware_sampling=False, preceding_tokens=None):
+def topk_sampling(
+    logits,
+    top_k=10,
+    top_p=1.0,
+    temperature=1.0,
+    repetition_aware_sampling=False,
+    preceding_tokens=None,
+):
    # temperature: (`optional`) float
    #     The value used to modulate the next token probabilities. Must be strictly positive. Default to 1.0.
    # top_k: (`optional`) int
@@ -1780,11 +1727,13 @@ def topk_sampling(logits, top_k=10, top_p=1.0, temperature=1.0, repetition_aware
    if temperature != 1.0:
        logits = logits / temperature
    # Top-p/top-k filtering
-    logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p, min_tokens_to_keep=2)
+    logits = top_k_top_p_filtering(
+        logits, top_k=top_k, top_p=top_p, min_tokens_to_keep=2
+    )
    # Sample
    probs = F.softmax(logits, dim=-1)
    # print top 10 value and index
    print("top 10 value and index", torch.topk(probs, 10), top_p)
    tokens = torch.multinomial(probs, num_samples=1)

    if repetition_aware_sampling:
@@ -1814,7 +1763,9 @@ def topk_sampling(logits, top_k=10, top_p=1.0, temperature=1.0, repetition_aware
            probs = F.softmax(logits[i], dim=-1)
            token_new = torch.multinomial(probs, num_samples=1)
-            print(f"Repetition Aware Sampling: {item}, {tokens[i]} -> {token_new}")
+            print(
+                f"Repetition Aware Sampling: {item}, {tokens[i]} -> {token_new}"
+            )
            print("probs", probs, logits.shape)
            tokens[i] = token_new
    else:
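(The branch above appears to implement repetition-aware sampling, as described for VALL-E 2: when the freshly sampled token repeats too often within a recent window of preceding tokens, it is redrawn from the filtered distribution. A compact sketch of the idea; the window size and repetition threshold are illustrative, not the recipe's exact values.)

import torch

def resample_if_repetitive(tokens, probs, preceding_tokens, window=10, threshold=0.5):
    # tokens: (B, 1) sampled ids; probs: (B, V) filtered distribution;
    # preceding_tokens: (B, T) decoding history per utterance.
    for i in range(tokens.shape[0]):
        recent = preceding_tokens[i, -window:]
        repeat_ratio = (recent == tokens[i]).float().mean().item()
        if repeat_ratio >= threshold:
            tokens[i] = torch.multinomial(probs[i], num_samples=1)  # redraw once
    return tokens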