mirror of https://github.com/k2-fsa/icefall.git (synced 2025-08-09 01:52:41 +00:00)

add eval seed tts

This commit is contained in:
parent 0f7ebb7ffb
commit d2b473ad99

egs/wenetspeech4tts/TTS/f5-tts/infer_dist.py (new file, 347 lines)
@@ -0,0 +1,347 @@
# Copyright (c) 2024 Tsinghua Univ. (authors: Xingchen Song)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Example Usage

Multi-GPU inference with torchrun (CUDA is required), e.g.:

torchrun --nproc_per_node=8 --nnodes=1 \
    --rdzv_id=2024 --rdzv_backend="c10d" --rdzv_endpoint="localhost:0" \
    infer_dist.py \
        --split-name test_zh \
        --output_dir ./seed_tts_eval_results \
        --batch_size 32 \
        --llm-model-name-or-path /path/to/llm \
        --tokenizer-dir /path/to/tokenizer \
        --vocoder-dir /path/to/bigvgan_vocoder \
        --flow-matching-model-path /path/to/f5_tts_checkpoint.pt

(plus any flow-matching model options registered by add_model_arguments).
"""

import argparse
import json
import os
from pathlib import Path

import s3tokenizer
import torch
import torch.distributed as dist
import torch.nn.functional as F
import torchaudio
from bigvganinference import BigVGANInference
from datasets import load_dataset
from lhotse.serialization import load_jsonl
from llm_tts import LLMTTS
from model.modules import MelSpec
from torch.utils.data import DataLoader, Dataset, DistributedSampler
from tqdm import tqdm
from train import (
    add_model_arguments,
    get_model,
    get_tokenizer,
    interpolate_tokens,
    load_F5_TTS_pretrained_checkpoint,
)

from icefall.checkpoint import load_checkpoint

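# Qwen-style "<|im_start|>role\n...<|im_end|>" (ChatML) Jinja chat template; it is
# passed to tokenizer.apply_chat_template() in data_collator() so the text prompt
# and the assistant speech-generation turn are rendered consistently.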
TEMPLATE = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content']}}{% if loop.last %}{{ '<|im_end|>'}}{% else %}{{ '<|im_end|>\n' }}{% endif %}{% endfor %}"


def get_args():
    parser = argparse.ArgumentParser(
        description="Distributed TTS inference for the seed_tts evaluation sets"
    )
    parser.add_argument(
        "--s3-tokenizer-name",
        required=False,
        type=str,
        choices=[
            "speech_tokenizer_v1",
            "speech_tokenizer_v1_25hz",
            "speech_tokenizer_v2_25hz",
        ],
        help="s3tokenizer model version",
    )
    parser.add_argument(
        "--split-name",
        type=str,
        default="wenetspeech4tts",
        choices=["wenetspeech4tts", "test_zh", "test_en", "test_hard"],
        help="huggingface dataset split name",
    )
    parser.add_argument(
        "--output_dir", required=True, type=str, help="dir to save results"
    )
    parser.add_argument(
        "--batch_size",
        required=True,
        type=int,
        help="batch size (per-device) for inference",
    )
    parser.add_argument(
        "--num_workers", type=int, default=4, help="workers for dataloader"
    )
    parser.add_argument(
        "--prefetch", type=int, default=5, help="prefetch for dataloader"
    )
    parser.add_argument(
        "--llm-model-name-or-path",
        required=True,
        type=str,
        help="LLM model name or path",
    )
    parser.add_argument(
        "--tokenizer-dir",
        required=True,
        type=str,
        help="tokenizer dir",
    )
    parser.add_argument(
        "--vocoder-dir",
        required=True,
        type=str,
        help="vocoder dir",
    )
    parser.add_argument(
        "--flow-matching-model-path",
        required=True,
        type=str,
        help="flow matching model path",
    )
    add_model_arguments(parser)
    args = parser.parse_args()
    return args


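# Pad a list of (n_mels, T) reference mel spectrograms to the longest T in the
# batch and return them stacked as (batch, T, n_mels), the layout later used as
# the `cond` input of the flow-matching model.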
def padded_mel_batch(ref_mels):
    max_mel_length = torch.LongTensor([mel.shape[-1] for mel in ref_mels]).amax()
    padded_ref_mels = []
    for mel in ref_mels:
        padded_ref_mel = F.pad(mel, (0, max_mel_length - mel.shape[-1]), value=0)
        padded_ref_mels.append(padded_ref_mel)
    padded_ref_mels = torch.stack(padded_ref_mels)
    padded_ref_mels = padded_ref_mels.permute(0, 2, 1)
    return padded_ref_mels


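# Collate a batch of seed_tts_cosy2 examples into (a) left-padded LLM input ids
# that end with the prompt's CosyVoice2 speech tokens and (b) padded reference
# mel spectrograms (plus their lengths) used to condition the flow-matching model.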
def data_collator(batch, tokenizer, mel_spectrogram):
    speech_generation_start_index = tokenizer.convert_tokens_to_ids(
        "<|SPEECH_GENERATION_START|>"
    )
    assistant_index = tokenizer.convert_tokens_to_ids("assistant")
    target_sample_rate = 24000
    hop_length = 256
    target_rms = 0.1
    input_ids_list, ref_mel_list, ref_mel_len_list = [], [], []
    for i, item in enumerate(batch):
        prompt_text, target_text, prompt_audio_codes = (
            item["prompt_text"],
            item["target_text"],
            item["prompt_audio_cosy2_tokens"],
        )
        message = [
            {
                "role": "user",
                "content": f"Convert the text to speech: {prompt_text + target_text}",
            },
            {"role": "assistant", "content": "<|SPEECH_GENERATION_START|>"},
        ]

        input_ids = tokenizer.apply_chat_template(
            message,
            tokenize=True,
            chat_template=TEMPLATE,
        )

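        # Map the raw CosyVoice2 codec ids onto the tokens added in llm_tts.py:
        # code i corresponds to the added token <|s_{i}|>, so the ids are shifted
        # by 151665, presumably the size of the base tokenizer vocabulary before
        # the speech tokens were appended (the decoder subtracts the same offset).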
        prompt_audio_codes = [c + 151665 for c in prompt_audio_codes]

        idx = input_ids.index(speech_generation_start_index)
        input_ids = input_ids[:idx] + prompt_audio_codes
        input_ids_list.append(input_ids)

        # get flow matching model's prompt mel spectrogram
        ref_audio_org, ref_sr = (
            item["prompt_audio"]["array"],
            item["prompt_audio"]["sampling_rate"],
        )
        ref_audio_org = torch.from_numpy(ref_audio_org).unsqueeze(0).float()
        ref_rms = torch.sqrt(torch.mean(torch.square(ref_audio_org)))
        if ref_rms < target_rms:
            ref_audio_org = ref_audio_org * target_rms / ref_rms

        if ref_sr != target_sample_rate:
            resampler = torchaudio.transforms.Resample(ref_sr, target_sample_rate)
            ref_audio = resampler(ref_audio_org)
        else:
            ref_audio = ref_audio_org

        # Duration in mel frame length
        ref_mel_len = ref_audio.shape[-1] // hop_length
        # to mel spectrogram
        ref_mel = mel_spectrogram(ref_audio)
        ref_mel = ref_mel.squeeze(0)

        ref_mel_list.append(ref_mel)
        ref_mel_len_list.append(ref_mel_len)

    max_len = max([len(input_ids) for input_ids in input_ids_list])
    input_ids_list = [
        [tokenizer.pad_token_id] * (max_len - len(input_ids)) + input_ids
        for input_ids in input_ids_list
    ]
    input_ids = torch.tensor(input_ids_list, dtype=torch.int64)
    attention_mask = input_ids.ne(tokenizer.pad_token_id).long()
    ids = [item["id"] for item in batch]

    ref_mel_batch = padded_mel_batch(ref_mel_list)
    ref_mel_len_batch = torch.LongTensor(ref_mel_len_list)
    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "ids": ids,
        "ref_mel_batch": ref_mel_batch,
        "ref_mel_len_batch": ref_mel_len_batch,
    }


def init_distributed():
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    rank = int(os.environ.get("RANK", 0))
    print(
        "Inference on multiple gpus, this gpu {}".format(local_rank)
        + ", rank {}, world_size {}".format(rank, world_size)
    )
    torch.cuda.set_device(local_rank)
    dist.init_process_group("nccl")
    return world_size, local_rank, rank


def main():
    args = get_args()
    os.makedirs(args.output_dir, exist_ok=True)

    assert torch.cuda.is_available()
    world_size, local_rank, rank = init_distributed()
    device = torch.device(f"cuda:{local_rank}")
    model = LLMTTS(
        model_dir=args.llm_model_name_or_path,
        tokenizer_dir=args.tokenizer_dir,
        s3_tokenizer_name=args.s3_tokenizer_name,
        device=device,
    )

    vocoder = BigVGANInference.from_pretrained(args.vocoder_dir, use_cuda_kernel=False)
    vocoder = vocoder.eval().to(device)

    flow_matching_model = get_model(args).eval().to(device)
    _ = load_checkpoint(
        args.flow_matching_model_path,
        model=flow_matching_model,
    )

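    # The seed_tts_cosy2 dataset provides, per utterance, the prompt transcript,
    # the target transcript, the prompt waveform, and pre-extracted CosyVoice2
    # speech tokens for the prompt (see data_collator above).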
    dataset = load_dataset(
        "yuekai/seed_tts_cosy2",
        split=args.split_name,
        trust_remote_code=True,
    )

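    # Shard the evaluation set across ranks; every rank synthesizes and saves its
    # own subset of wavs, so no gather step is needed at the end.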
    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)

    mel_spectrogram = MelSpec(
        n_fft=1024,
        hop_length=256,
        win_length=1024,
        n_mel_channels=100,
        target_sample_rate=24000,
        mel_spec_type="bigvgan",
    )

    dataloader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        sampler=sampler,
        shuffle=False,
        num_workers=args.num_workers,
        prefetch_factor=args.prefetch,
        collate_fn=lambda x: data_collator(x, model.tokenizer, mel_spectrogram),
    )

    total_steps = len(dataset)

    if rank == 0:
        progress_bar = tqdm(total=total_steps, desc="Processing", unit="wavs")

    for batch in dataloader:
        generate_codes = model.inference_batch(
            batch["input_ids"], batch["attention_mask"]
        )
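        # interpolate_tokens() (from train.py) expands the generated speech tokens
        # to mel-frame rate; the length of the interpolated sequence is used as the
        # total mel duration requested from the flow-matching model.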
        flow_matching_input_tokens, total_mel_lens = [], []
        for i, code in enumerate(generate_codes):
            flow_matching_input_token = interpolate_tokens(code)
            total_mel_len = len(flow_matching_input_token)
            flow_matching_input_tokens.append(flow_matching_input_token)
            total_mel_lens.append(total_mel_len)
        total_mel_lens = torch.tensor(total_mel_lens, dtype=torch.long).to(device)
        ref_mels, ref_mel_lens = batch["ref_mel_batch"].to(device), batch[
            "ref_mel_len_batch"
        ].to(device)

        max_len = max([len(tokens) for tokens in flow_matching_input_tokens])
        # pad tokens to the same length
        for i, tokens in enumerate(flow_matching_input_tokens):
            flow_matching_input_tokens[i] = torch.tensor(
                tokens + [-1] * (max_len - len(tokens)), dtype=torch.long
            )
        flow_matching_input_tokens = torch.stack(flow_matching_input_tokens).to(device)
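        # Flow-matching generation conditioned on the reference mel: `lens` marks
        # how many frames of `cond` belong to the prompt, `duration` is the total
        # number of frames to produce, and 16 ODE steps with CFG strength 2.0 and
        # sway sampling follow the F5-TTS inference recipe.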
        generated, _ = flow_matching_model.sample(
            cond=ref_mels,
            text=flow_matching_input_tokens,
            duration=total_mel_lens,
            lens=ref_mel_lens,
            steps=16,
            cfg_strength=2.0,
            sway_sampling_coef=-1,
            no_ref_audio=False,
            seed=0,
        )

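        # Keep only the newly generated frames (drop the prompt region) before
        # vocoding each sample with BigVGAN.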
        for i, gen in enumerate(generated):
            gen = gen[ref_mel_lens[i] : total_mel_lens[i], :].unsqueeze(0)
            gen_mel_spec = gen.permute(0, 2, 1).to(torch.float32)

            generated_wave = vocoder(gen_mel_spec).squeeze(0).cpu()
            target_rms = 0.1
            target_sample_rate = 24_000
            # if ref_rms_list[i] < target_rms:
            #     generated_wave = generated_wave * ref_rms_list[i] / target_rms
            utt = batch["ids"][i]
            torchaudio.save(
                f"{args.output_dir}/{utt}.wav",
                generated_wave,
                target_sample_rate,
            )

        if rank == 0:
            progress_bar.update(world_size * len(batch["ids"]))

    if rank == 0:
        progress_bar.close()

    dist.barrier()
    dist.destroy_process_group()


if __name__ == "__main__":
    main()
egs/wenetspeech4tts/TTS/f5-tts/llm_tts.py (new file, 108 lines)
@@ -0,0 +1,108 @@
# Copyright (c) 2025 SparkAudio
#               2025 Xinsheng Wang (w.xinshawn@gmail.com)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# https://github.com/SparkAudio/Spark-TTS/blob/main/cli/SparkTTS.py

import re
from pathlib import Path
from typing import List

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer


class LLMTTS:
    """
    LLM-TTS for text-to-speech generation: an autoregressive LLM that emits
    discrete speech tokens, to be rendered by a flow-matching acoustic model.
    """

    def __init__(
        self,
        model_dir: Path,
        tokenizer_dir: Path,
        s3_tokenizer_name: str,
        device: torch.device,
    ):
        """
        Initializes the LLMTTS model with the provided configurations and device.

        Args:
            model_dir (Path): Directory containing the LLM weights and config files.
            tokenizer_dir (Path): Directory containing the base tokenizer to extend
                with speech tokens.
            s3_tokenizer_name (str): s3tokenizer model version (currently unused in this class).
            device (torch.device): The device (CPU/GPU) to run the model on.
        """
        self.device = device

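        # Load the LLM in float16 with FlashAttention-2; this requires the
        # flash-attn package and a GPU that supports it.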
        self.model = AutoModelForCausalLM.from_pretrained(
            model_dir,
            torch_dtype=torch.float16,
            device_map=device,
            attn_implementation="flash_attention_2",
        )

        tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir)
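        # Extend the vocabulary with 6561 speech-code tokens <|s_0|> ... <|s_6560|>
        # (6561 = 3**8, presumably the codec codebook size) plus a marker token
        # separating text from speech generation. Padding is on the left so that
        # batched generation starts right after the real prompt tokens.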
        new_tokens = [f"<|s_{i}|>" for i in range(6561)] + [
            "<|SPEECH_GENERATION_START|>"
        ]
        num_added_tokens = tokenizer.add_tokens(new_tokens)
        tokenizer.padding_side = "left"
        self.tokenizer = tokenizer
        self.assistant_index = tokenizer.convert_tokens_to_ids("assistant")

    @torch.no_grad()
    def inference_batch(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        temperature: float = 0.8,
        top_k: float = 50,
        top_p: float = 0.95,
    ) -> List[List[int]]:
        """
        Generates speech tokens for a batch of left-padded chat prompts.

        Args:
            input_ids (torch.Tensor): Left-padded prompt token ids of shape (batch, seq_len).
            attention_mask (torch.Tensor): Mask of the same shape; 0 marks padding positions.
            temperature (float, optional): Sampling temperature for controlling randomness. Default is 0.8.
            top_k (float, optional): Top-k sampling parameter. Default is 50.
            top_p (float, optional): Top-p (nucleus) sampling parameter. Default is 0.95.

        Returns:
            List[List[int]]: Speech-codec token ids (vocabulary offset removed) for each batch item.
        """
        # Generate speech tokens with the LLM
        generated_ids = self.model.generate(
            input_ids=input_ids.to(self.device),
            attention_mask=attention_mask.to(self.device),
            max_new_tokens=1024,
            do_sample=True,
            top_k=top_k,
            top_p=top_p,
            temperature=temperature,
        )

        results = []
        generated_ids = generated_ids.cpu().tolist()
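        # Everything up to the "assistant" role token and the following newline
        # token belongs to the text prompt, hence the `+ 2` offset. Subtracting
        # 151665 maps the ids back to raw codec codes, and the >= 0 filter drops
        # any remaining non-speech tokens (e.g. <|im_end|> or padding), so the
        # result is the prompt's speech codes followed by the newly generated ones.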
        for i in range(len(generated_ids)):
            assistant_index = generated_ids[i].index(self.assistant_index)
            padding_index = len(generated_ids[i])
            result = generated_ids[i][assistant_index + 2 :]
            result = [token - 151665 for token in result]
            result = [token for token in result if token >= 0]
            results.append(result)
        return results