mirror of https://github.com/k2-fsa/icefall.git
synced 2025-08-09 01:52:41 +00:00

add eval seed tts

parent 0f7ebb7ffb
commit d2b473ad99

egs/wenetspeech4tts/TTS/f5-tts/infer_dist.py | 347 (new file)
@@ -0,0 +1,347 @@
# Copyright (c) 2024 Tsinghua Univ. (authors: Xingchen Song)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Example Usage
Launch one process per GPU with torchrun (the script requires CUDA). All
paths below are placeholders, and the flow-matching model flags added by
add_model_arguments() in train.py are omitted:

torchrun --nproc_per_node=8 --nnodes=1 \
    --rdzv_id=2024 --rdzv_backend="c10d" --rdzv_endpoint="localhost:0" \
    infer_dist.py \
        --llm-model-name-or-path /path/to/speech_llm \
        --tokenizer-dir /path/to/tokenizer \
        --vocoder-dir /path/to/bigvgan \
        --flow-matching-model-path /path/to/f5-tts-checkpoint.pt \
        --split-name test_zh \
        --output_dir ./seed_tts_eval_wavs \
        --batch_size 32
"""

import argparse
import json
import os
from pathlib import Path

import s3tokenizer
import torch
import torch.distributed as dist
import torch.nn.functional as F
import torchaudio
from bigvganinference import BigVGANInference
from datasets import load_dataset
from lhotse.serialization import load_jsonl
from llm_tts import LLMTTS
from model.modules import MelSpec
from torch.utils.data import DataLoader, Dataset, DistributedSampler
from tqdm import tqdm
from train import (
    add_model_arguments,
    get_model,
    get_tokenizer,
    interpolate_tokens,
    load_F5_TTS_pretrained_checkpoint,
)

from icefall.checkpoint import load_checkpoint

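# ChatML-style Jinja chat template (<|im_start|>role\n content <|im_end|>); it is
# passed as chat_template= to tokenizer.apply_chat_template() in data_collator below.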
TEMPLATE = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content']}}{% if loop.last %}{{ '<|im_end|>'}}{% else %}{{ '<|im_end|>\n' }}{% endif %}{% endfor %}"


def get_args():
    parser = argparse.ArgumentParser(
        description="distributed TTS inference for seed-tts-style evaluation"
    )
    parser.add_argument(
        "--s3-tokenizer-name",
        required=False,
        type=str,
        choices=[
            "speech_tokenizer_v1",
            "speech_tokenizer_v1_25hz",
            "speech_tokenizer_v2_25hz",
        ],
        help="s3tokenizer model version",
    )
    parser.add_argument(
        "--split-name",
        type=str,
        default="wenetspeech4tts",
        choices=["wenetspeech4tts", "test_zh", "test_en", "test_hard"],
        help="huggingface dataset split name",
    )
    parser.add_argument(
        "--output_dir", required=True, type=str, help="dir to save result"
    )
    parser.add_argument(
        "--batch_size",
        required=True,
        type=int,
        help="batch size (per-device) for inference",
    )
    parser.add_argument(
        "--num_workers", type=int, default=4, help="workers for dataloader"
    )
    parser.add_argument(
        "--prefetch", type=int, default=5, help="prefetch for dataloader"
    )
    parser.add_argument(
        "--llm-model-name-or-path",
        required=True,
        type=str,
        help="path or name of the pretrained speech LLM",
    )
    parser.add_argument(
        "--tokenizer-dir",
        required=True,
        type=str,
        help="tokenizer dir",
    )
    parser.add_argument(
        "--vocoder-dir",
        required=True,
        type=str,
        help="vocoder dir",
    )
    parser.add_argument(
        "--flow-matching-model-path",
        required=True,
        type=str,
        help="flow matching model path",
    )
    add_model_arguments(parser)
    args = parser.parse_args()
    return args


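# Right-pad each (n_mels, T) reference mel along time to the batch maximum and
# return a single (batch, time, n_mels) tensor.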
def padded_mel_batch(ref_mels):
    max_mel_length = torch.LongTensor([mel.shape[-1] for mel in ref_mels]).amax()
    padded_ref_mels = []
    for mel in ref_mels:
        padded_ref_mel = F.pad(mel, (0, max_mel_length - mel.shape[-1]), value=0)
        padded_ref_mels.append(padded_ref_mel)
    padded_ref_mels = torch.stack(padded_ref_mels)
    padded_ref_mels = padded_ref_mels.permute(0, 2, 1)
    return padded_ref_mels


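# Collate HuggingFace dataset rows into (a) left-padded LLM prompt ids, built from
# the ChatML template followed by the prompt audio's CosyVoice2 codec tokens, and
# (b) reference mel spectrograms and lengths for conditioning the flow-matching model.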
def data_collator(batch, tokenizer, mel_spectrogram):
    speech_generation_start_index = tokenizer.convert_tokens_to_ids(
        "<|SPEECH_GENERATION_START|>"
    )
    assistant_index = tokenizer.convert_tokens_to_ids("assistant")
    target_sample_rate = 24000
    hop_length = 256
    target_rms = 0.1
    input_ids_list, ref_mel_list, ref_mel_len_list = [], [], []
    for i, item in enumerate(batch):
        prompt_text, target_text, prompt_audio_codes = (
            item["prompt_text"],
            item["target_text"],
            item["prompt_audio_cosy2_tokens"],
        )
        message = [
            {
                "role": "user",
                "content": f"Convert the text to speech: {prompt_text + target_text}",
            },
            {"role": "assistant", "content": "<|SPEECH_GENERATION_START|>"},
        ]

        input_ids = tokenizer.apply_chat_template(
            message,
            tokenize=True,
            chat_template=TEMPLATE,
        )

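        # Shift the codec indices into the LLM vocabulary: speech token i is
        # assumed to sit at id 151665 + i, just past the text vocabulary
        # (llm_tts.py subtracts the same offset when decoding).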
        prompt_audio_codes = [c + 151665 for c in prompt_audio_codes]

        idx = input_ids.index(speech_generation_start_index)
        input_ids = input_ids[:idx] + prompt_audio_codes
        input_ids_list.append(input_ids)

        # get flow matching model's prompt mel spectrogram
        ref_audio_org, ref_sr = (
            item["prompt_audio"]["array"],
            item["prompt_audio"]["sampling_rate"],
        )
        ref_audio_org = torch.from_numpy(ref_audio_org).unsqueeze(0).float()
        ref_rms = torch.sqrt(torch.mean(torch.square(ref_audio_org)))
        if ref_rms < target_rms:
            ref_audio_org = ref_audio_org * target_rms / ref_rms

        if ref_sr != target_sample_rate:
            resampler = torchaudio.transforms.Resample(ref_sr, target_sample_rate)
            ref_audio = resampler(ref_audio_org)
        else:
            ref_audio = ref_audio_org

        # Duration in mel frame length
        ref_mel_len = ref_audio.shape[-1] // hop_length
        # to mel spectrogram
        ref_mel = mel_spectrogram(ref_audio)
        ref_mel = ref_mel.squeeze(0)

        ref_mel_list.append(ref_mel)
        ref_mel_len_list.append(ref_mel_len)

    max_len = max([len(input_ids) for input_ids in input_ids_list])
    input_ids_list = [
        [tokenizer.pad_token_id] * (max_len - len(input_ids)) + input_ids
        for input_ids in input_ids_list
    ]
    input_ids = torch.tensor(input_ids_list, dtype=torch.int64)
    attention_mask = input_ids.ne(tokenizer.pad_token_id).long()
    ids = [item["id"] for item in batch]

    ref_mel_batch = padded_mel_batch(ref_mel_list)
    ref_mel_len_batch = torch.LongTensor(ref_mel_len_list)
    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "ids": ids,
        "ref_mel_batch": ref_mel_batch,
        "ref_mel_len_batch": ref_mel_len_batch,
    }


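# Read the process-group environment variables set by torchrun and bind this
# process to its local GPU before creating the NCCL process group.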
def init_distributed():
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    rank = int(os.environ.get("RANK", 0))
    print(
        "Inference on multiple gpus, this gpu {}".format(local_rank)
        + ", rank {}, world_size {}".format(rank, world_size)
    )
    torch.cuda.set_device(local_rank)
    dist.init_process_group("nccl")
    return world_size, local_rank, rank


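# Per batch: the speech LLM generates codec tokens, interpolate_tokens() (from
# train.py) stretches them to the target mel length, the F5-TTS flow-matching
# model produces mel frames conditioned on the reference mel, and BigVGAN
# vocodes the mels into waveforms that are written to --output_dir.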
def main():
    args = get_args()
    os.makedirs(args.output_dir, exist_ok=True)

    assert torch.cuda.is_available()
    world_size, local_rank, rank = init_distributed()
    device = torch.device(f"cuda:{local_rank}")
    model = LLMTTS(
        model_dir=args.llm_model_name_or_path,
        tokenizer_dir=args.tokenizer_dir,
        s3_tokenizer_name=args.s3_tokenizer_name,
        device=device,
    )

    vocoder = BigVGANInference.from_pretrained(args.vocoder_dir, use_cuda_kernel=False)
    vocoder = vocoder.eval().to(device)

    flow_matching_model = get_model(args).eval().to(device)
    _ = load_checkpoint(
        args.flow_matching_model_path,
        model=flow_matching_model,
    )

    dataset = load_dataset(
        "yuekai/seed_tts_cosy2",
        split=args.split_name,
        trust_remote_code=True,
    )

    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)

    mel_spectrogram = MelSpec(
        n_fft=1024,
        hop_length=256,
        win_length=1024,
        n_mel_channels=100,
        target_sample_rate=24000,
        mel_spec_type="bigvgan",
    )

    dataloader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        sampler=sampler,
        shuffle=False,
        num_workers=args.num_workers,
        prefetch_factor=args.prefetch,
        collate_fn=lambda x: data_collator(x, model.tokenizer, mel_spectrogram),
    )

    total_steps = len(dataset)

    if rank == 0:
        progress_bar = tqdm(total=total_steps, desc="Processing", unit="wavs")

    for batch in dataloader:
        generate_codes = model.inference_batch(
            batch["input_ids"], batch["attention_mask"]
        )
        flow_matching_input_tokens, total_mel_lens = [], []
        for i, code in enumerate(generate_codes):
            flow_matching_input_token = interpolate_tokens(code)
            total_mel_len = len(flow_matching_input_token)
            flow_matching_input_tokens.append(flow_matching_input_token)
            total_mel_lens.append(total_mel_len)
        total_mel_lens = torch.tensor(total_mel_lens, dtype=torch.long).to(device)
        ref_mels, ref_mel_lens = batch["ref_mel_batch"].to(device), batch[
            "ref_mel_len_batch"
        ].to(device)

        max_len = max([len(tokens) for tokens in flow_matching_input_tokens])
        # pad tokens to the same length
        for i, tokens in enumerate(flow_matching_input_tokens):
            flow_matching_input_tokens[i] = torch.tensor(
                tokens + [-1] * (max_len - len(tokens)), dtype=torch.long
            )
        flow_matching_input_tokens = torch.stack(flow_matching_input_tokens).to(device)
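        # The flow-matching model samples mel frames for the whole utterance
        # (prompt + target) conditioned on the reference mel; the prompt frames
        # are stripped off below before vocoding.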
        generated, _ = flow_matching_model.sample(
            cond=ref_mels,
            text=flow_matching_input_tokens,
            duration=total_mel_lens,
            lens=ref_mel_lens,
            steps=16,
            cfg_strength=2.0,
            sway_sampling_coef=-1,
            no_ref_audio=False,
            seed=0,
        )

        for i, gen in enumerate(generated):
            gen = gen[ref_mel_lens[i] : total_mel_lens[i], :].unsqueeze(0)
            gen_mel_spec = gen.permute(0, 2, 1).to(torch.float32)

            generated_wave = vocoder(gen_mel_spec).squeeze(0).cpu()
            target_rms = 0.1
            target_sample_rate = 24_000
            # if ref_rms_list[i] < target_rms:
            #     generated_wave = generated_wave * ref_rms_list[i] / target_rms
            utt = batch["ids"][i]
            torchaudio.save(
                f"{args.output_dir}/{utt}.wav",
                generated_wave,
                target_sample_rate,
            )

        if rank == 0:
            progress_bar.update(world_size * len(batch["ids"]))

    if rank == 0:
        progress_bar.close()

    dist.barrier()
    dist.destroy_process_group()


if __name__ == "__main__":
    main()

egs/wenetspeech4tts/TTS/f5-tts/llm_tts.py | 108 (new file)
@@ -0,0 +1,108 @@
# Copyright (c) 2025 SparkAudio
#               2025 Xinsheng Wang (w.xinshawn@gmail.com)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# https://github.com/SparkAudio/Spark-TTS/blob/main/cli/SparkTTS.py

import re
from pathlib import Path
from typing import List

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer


class LLMTTS:
    """
    LLM-TTS for text-to-speech generation.
    """

    def __init__(
        self,
        model_dir: Path,
        tokenizer_dir: Path,
        s3_tokenizer_name: str,
        device: torch.device,
    ):
        """
        Initializes the LLMTTS model with the provided configurations and device.

        Args:
            model_dir (Path): Directory containing the model and config files.
            tokenizer_dir (Path): Directory containing the text tokenizer.
            s3_tokenizer_name (str): s3tokenizer model version (currently unused here).
            device (torch.device): The device (CPU/GPU) to run the model on.
        """
        self.device = device

        self.model = AutoModelForCausalLM.from_pretrained(
            model_dir,
            torch_dtype=torch.float16,
            device_map=device,
            attn_implementation="flash_attention_2",
        )

        tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir)
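        # Extend the tokenizer with 6561 codec tokens <|s_0|> .. <|s_6560|>
        # (assumed to match the CosyVoice2 codebook size) plus a generation-start
        # marker; left padding keeps batched prompts right-aligned for generation.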
        new_tokens = [f"<|s_{i}|>" for i in range(6561)] + [
            "<|SPEECH_GENERATION_START|>"
        ]
        num_added_tokens = tokenizer.add_tokens(new_tokens)
        tokenizer.padding_side = "left"
        self.tokenizer = tokenizer
        self.assistant_index = tokenizer.convert_tokens_to_ids("assistant")

    @torch.no_grad()
    def inference_batch(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        temperature: float = 0.8,
        top_k: float = 50,
        top_p: float = 0.95,
    ) -> List[List[int]]:
        """
        Performs batched inference, generating speech codec tokens from the
        prepared prompt ids.

        Args:
            input_ids (torch.Tensor): Left-padded prompt token ids of shape (batch, seq_len).
            attention_mask (torch.Tensor): Attention mask matching input_ids (0 on padding).
            temperature (float, optional): Sampling temperature for controlling randomness. Default is 0.8.
            top_k (float, optional): Top-k sampling parameter. Default is 50.
            top_p (float, optional): Top-p (nucleus) sampling parameter. Default is 0.95.

        Returns:
            List[List[int]]: One list of codec token indices per batch item.
        """
        # Generate speech using the model
        generated_ids = self.model.generate(
            input_ids=input_ids.to(self.device),
            attention_mask=attention_mask.to(self.device),
            max_new_tokens=1024,
            do_sample=True,
            top_k=top_k,
            top_p=top_p,
            temperature=temperature,
        )

        results = []
        generated_ids = generated_ids.cpu().tolist()
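        # Everything after the "assistant" token and the token that follows it
        # (presumably the newline) is codec content: the prompt codes that were
        # part of the input plus the newly generated codes. Subtracting the
        # 151665 vocabulary offset recovers codec indices; negative values
        # (ordinary text tokens such as <|im_end|> or padding) are dropped.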
        for i in range(len(generated_ids)):
            assistant_index = generated_ids[i].index(self.assistant_index)
            padding_index = len(generated_ids[i])
            result = generated_ids[i][assistant_index + 2 :]
            result = [token - 151665 for token in result]
            result = [token for token in result if token >= 0]
            results.append(result)
        return results