add eval seed tts

This commit is contained in:
root 2025-02-28 09:54:22 +00:00
parent 0f7ebb7ffb
commit d2b473ad99
2 changed files with 455 additions and 0 deletions

View File

@ -0,0 +1,347 @@
# Copyright (c) 2024 Tsinghua Univ. (authors: Xingchen Song)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Example Usage
cpu:
s3tokenizer --data_dir xxx.scp \
--device "cpu" \
--output_dir "./" \
--batch_size 32
gpu:
torchrun --nproc_per_node=8 --nnodes=1 \
--rdzv_id=2024 --rdzv_backend="c10d" --rdzv_endpoint="localhost:0" \
`which s3tokenizer` --data_dir xxx.scp \
--device "cuda" \
--output_dir "./" \
--batch_size 32
"""
import argparse
import json
import os
from pathlib import Path
import s3tokenizer
import torch
import torch.distributed as dist
import torch.nn.functional as F
import torchaudio
from bigvganinference import BigVGANInference
from datasets import load_dataset
from lhotse.serialization import load_jsonl
from llm_tts import LLMTTS
from model.modules import MelSpec
from torch.utils.data import DataLoader, Dataset, DistributedSampler
from tqdm import tqdm
from train import (
add_model_arguments,
get_model,
get_tokenizer,
interpolate_tokens,
load_F5_TTS_pretrained_checkpoint,
)
from icefall.checkpoint import load_checkpoint
TEMPLATE = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content']}}{% if loop.last %}{{ '<|im_end|>'}}{% else %}{{ '<|im_end|>\n' }}{% endif %}{% endfor %}"
def get_args():
parser = argparse.ArgumentParser(description="extract speech code")
parser.add_argument(
"--s3-tokenizer-name",
required=False,
type=str,
choices=[
"speech_tokenizer_v1",
"speech_tokenizer_v1_25hz",
"speech_tokenizer_v2_25hz",
],
help="model version",
)
parser.add_argument(
"--split-name",
type=str,
default="wenetspeech4tts",
choices=["wenetspeech4tts", "test_zh", "test_en", "test_hard"],
help="huggingface dataset split name",
)
parser.add_argument(
"--output_dir", required=True, type=str, help="dir to save result"
)
parser.add_argument(
"--batch_size",
required=True,
type=int,
help="batch size (per-device) for inference",
)
parser.add_argument(
"--num_workers", type=int, default=4, help="workers for dataloader"
)
parser.add_argument(
"--prefetch", type=int, default=5, help="prefetch for dataloader"
)
parser.add_argument(
"--llm-model-name-or-path",
required=True,
type=str,
help="model version",
)
parser.add_argument(
"--tokenizer-dir",
required=True,
type=str,
help="tokenizer dir",
)
parser.add_argument(
"--vocoder-dir",
required=True,
type=str,
help="vocoder dir",
)
parser.add_argument(
"--flow-matching-model-path",
required=True,
type=str,
help="flow matching model path",
)
add_model_arguments(parser)
args = parser.parse_args()
return args
def padded_mel_batch(ref_mels):
max_mel_length = torch.LongTensor([mel.shape[-1] for mel in ref_mels]).amax()
padded_ref_mels = []
for mel in ref_mels:
padded_ref_mel = F.pad(mel, (0, max_mel_length - mel.shape[-1]), value=0)
padded_ref_mels.append(padded_ref_mel)
padded_ref_mels = torch.stack(padded_ref_mels)
padded_ref_mels = padded_ref_mels.permute(0, 2, 1)
return padded_ref_mels
def data_collator(batch, tokenizer, mel_spectrogram):
speech_generation_start_index = tokenizer.convert_tokens_to_ids(
"<|SPEECH_GENERATION_START|>"
)
assistant_index = tokenizer.convert_tokens_to_ids("assistant")
target_sample_rate = 24000
hop_length = 256
target_rms = 0.1
input_ids_list, ref_mel_list, ref_mel_len_list = [], [], []
for i, item in enumerate(batch):
prompt_text, target_text, prompt_audio_codes = (
item["prompt_text"],
item["target_text"],
item["prompt_audio_cosy2_tokens"],
)
message = [
{
"role": "user",
"content": f"Convert the text to speech: {prompt_text + target_text}",
},
{"role": "assistant", "content": "<|SPEECH_GENERATION_START|>"},
]
input_ids = tokenizer.apply_chat_template(
message,
tokenize=True,
chat_template=TEMPLATE,
)
prompt_audio_codes = [c + 151665 for c in prompt_audio_codes]
idx = input_ids.index(speech_generation_start_index)
input_ids = input_ids[:idx] + prompt_audio_codes
input_ids_list.append(input_ids)
# get flow matching model's prompt mel spectrogram
ref_audio_org, ref_sr = (
item["prompt_audio"]["array"],
item["prompt_audio"]["sampling_rate"],
)
ref_audio_org = torch.from_numpy(ref_audio_org).unsqueeze(0).float()
ref_rms = torch.sqrt(torch.mean(torch.square(ref_audio_org)))
if ref_rms < target_rms:
ref_audio_org = ref_audio_org * target_rms / ref_rms
if ref_sr != target_sample_rate:
resampler = torchaudio.transforms.Resample(ref_sr, target_sample_rate)
ref_audio = resampler(ref_audio_org)
else:
ref_audio = ref_audio_org
# Duration in mel frame length
ref_mel_len = ref_audio.shape[-1] // hop_length
# to mel spectrogram
ref_mel = mel_spectrogram(ref_audio)
ref_mel = ref_mel.squeeze(0)
ref_mel_list.append(ref_mel)
ref_mel_len_list.append(ref_mel_len)
max_len = max([len(input_ids) for input_ids in input_ids_list])
input_ids_list = [
[tokenizer.pad_token_id] * (max_len - len(input_ids)) + input_ids
for input_ids in input_ids_list
]
input_ids = torch.tensor(input_ids_list, dtype=torch.int64)
attention_mask = input_ids.ne(tokenizer.pad_token_id).long()
ids = [item["id"] for item in batch]
ref_mel_batch = padded_mel_batch(ref_mel_list)
ref_mel_len_batch = torch.LongTensor(ref_mel_len_list)
return {
"input_ids": input_ids,
"attention_mask": attention_mask,
"ids": ids,
"ref_mel_batch": ref_mel_batch,
"ref_mel_len_batch": ref_mel_len_batch,
}
def init_distributed():
world_size = int(os.environ.get("WORLD_SIZE", 1))
local_rank = int(os.environ.get("LOCAL_RANK", 0))
rank = int(os.environ.get("RANK", 0))
print(
"Inference on multiple gpus, this gpu {}".format(local_rank)
+ ", rank {}, world_size {}".format(rank, world_size)
)
torch.cuda.set_device(local_rank)
dist.init_process_group("nccl")
return world_size, local_rank, rank
def main():
args = get_args()
os.makedirs(args.output_dir, exist_ok=True)
assert torch.cuda.is_available()
world_size, local_rank, rank = init_distributed()
device = torch.device(f"cuda:{local_rank}")
model = LLMTTS(
model_dir=args.llm_model_name_or_path,
tokenizer_dir=args.tokenizer_dir,
s3_tokenizer_name=args.s3_tokenizer_name,
device=device,
)
vocoder = BigVGANInference.from_pretrained(args.vocoder_dir, use_cuda_kernel=False)
vocoder = vocoder.eval().to(device)
flow_matching_model = get_model(args).eval().to(device)
_ = load_checkpoint(
args.flow_matching_model_path,
model=flow_matching_model,
)
dataset = load_dataset(
"yuekai/seed_tts_cosy2",
split=args.split_name,
trust_remote_code=True,
)
sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
mel_spectrogram = MelSpec(
n_fft=1024,
hop_length=256,
win_length=1024,
n_mel_channels=100,
target_sample_rate=24000,
mel_spec_type="bigvgan",
)
dataloader = DataLoader(
dataset,
batch_size=args.batch_size,
sampler=sampler,
shuffle=False,
num_workers=args.num_workers,
prefetch_factor=args.prefetch,
collate_fn=lambda x: data_collator(x, model.tokenizer, mel_spectrogram),
)
total_steps = len(dataset)
if rank == 0:
progress_bar = tqdm(total=total_steps, desc="Processing", unit="wavs")
for batch in dataloader:
generate_codes = model.inference_batch(
batch["input_ids"], batch["attention_mask"]
)
flow_matching_input_tokens, total_mel_lens = [], []
for i, code in enumerate(generate_codes):
flow_matching_input_token = interpolate_tokens(code)
total_mel_len = len(flow_matching_input_token)
flow_matching_input_tokens.append(flow_matching_input_token)
total_mel_lens.append(total_mel_len)
total_mel_lens = torch.tensor(total_mel_lens, dtype=torch.long).to(device)
ref_mels, ref_mel_lens = batch["ref_mel_batch"].to(device), batch[
"ref_mel_len_batch"
].to(device)
max_len = max([len(tokens) for tokens in flow_matching_input_tokens])
# pad tokens to the same length
for i, tokens in enumerate(flow_matching_input_tokens):
flow_matching_input_tokens[i] = torch.tensor(
tokens + [-1] * (max_len - len(tokens)), dtype=torch.long
)
flow_matching_input_tokens = torch.stack(flow_matching_input_tokens).to(device)
generated, _ = flow_matching_model.sample(
cond=ref_mels,
text=flow_matching_input_tokens,
duration=total_mel_lens,
lens=ref_mel_lens,
steps=16,
cfg_strength=2.0,
sway_sampling_coef=-1,
no_ref_audio=False,
seed=0,
)
for i, gen in enumerate(generated):
gen = gen[ref_mel_lens[i] : total_mel_lens[i], :].unsqueeze(0)
gen_mel_spec = gen.permute(0, 2, 1).to(torch.float32)
generated_wave = vocoder(gen_mel_spec).squeeze(0).cpu()
target_rms = 0.1
target_sample_rate = 24_000
# if ref_rms_list[i] < target_rms:
# generated_wave = generated_wave * ref_rms_list[i] / target_rms
utt = batch["ids"][i]
torchaudio.save(
f"{args.output_dir}/{utt}.wav",
generated_wave,
target_sample_rate,
)
if rank == 0:
progress_bar.update(world_size * len(batch["ids"]))
if rank == 0:
progress_bar.close()
dist.barrier()
dist.destroy_process_group()
if __name__ == "__main__":
main()

View File

@ -0,0 +1,108 @@
# Copyright (c) 2025 SparkAudio
# 2025 Xinsheng Wang (w.xinshawn@gmail.com)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# https://github.com/SparkAudio/Spark-TTS/blob/main/cli/SparkTTS.py
import re
from pathlib import Path
from typing import List
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
class LLMTTS:
"""
LLM-TTS for text-to-speech generation.
"""
def __init__(
self,
model_dir: Path,
tokenizer_dir: Path,
s3_tokenizer_name: str,
device: torch.device,
):
"""
Initializes the LLMTTS model with the provided configurations and device.
Args:
model_dir (Path): Directory containing the model and config files.
device (torch.device): The device (CPU/GPU) to run the model on.
"""
self.device = device
self.model = AutoModelForCausalLM.from_pretrained(
model_dir,
torch_dtype=torch.float16,
device_map=device,
attn_implementation="flash_attention_2",
)
tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir)
new_tokens = [f"<|s_{i}|>" for i in range(6561)] + [
"<|SPEECH_GENERATION_START|>"
]
num_added_tokens = tokenizer.add_tokens(new_tokens)
tokenizer.padding_side = "left"
self.tokenizer = tokenizer
self.assistant_index = tokenizer.convert_tokens_to_ids("assistant")
@torch.no_grad()
def inference_batch(
self,
input_ids: torch.Tensor,
attention_mask: torch.Tensor,
temperature: float = 0.8,
top_k: float = 50,
top_p: float = 0.95,
) -> torch.Tensor:
"""
Performs inference to generate speech from text, incorporating prompt audio and/or text.
Args:
text (str): The text input to be converted to speech.
prompt_speech_path (Path): Path to the audio file used as a prompt.
prompt_text (str, optional): Transcript of the prompt audio.
gender (str): female | male.
pitch (str): very_low | low | moderate | high | very_high
speed (str): very_low | low | moderate | high | very_high
temperature (float, optional): Sampling temperature for controlling randomness. Default is 0.8.
top_k (float, optional): Top-k sampling parameter. Default is 50.
top_p (float, optional): Top-p (nucleus) sampling parameter. Default is 0.95.
Returns:
torch.Tensor: Generated waveform as a tensor.
"""
# Generate speech using the model
generated_ids = self.model.generate(
input_ids=input_ids.to(self.device),
attention_mask=attention_mask.to(self.device),
max_new_tokens=1024,
do_sample=True,
top_k=top_k,
top_p=top_p,
temperature=temperature,
)
results = []
generated_ids = generated_ids.cpu().tolist()
for i in range(len(generated_ids)):
assistant_index = generated_ids[i].index(self.assistant_index)
padding_index = len(generated_ids[i])
result = generated_ids[i][assistant_index + 2 :]
result = [token - 151665 for token in result]
result = [token for token in result if token >= 0]
results.append(result)
return results